From 8c29b37d7405ac37aab537a8071b7fa9b7d01dc3 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Mon, 20 May 2024 15:02:47 -0500 Subject: [PATCH] Tile-level partitioning in jr/ir loops (ex-trsm). (#695) Details: - Reimplemented parallelization of the JR loop in gemmt (which is recycled for herk, her2k, syrk, and syr2k). Previously, the rectangular region of the current MC x NC panel of C would be parallelized separately from from the diagonal region of that same submatrix, with the rectangular portion being assigned to threads via slab or round-robin (rr) partitioning (as determined at configure- time) and the diagonal region being assigned via round-robin. This approach did not work well when extracting lots of parallelism from the JR loop and was often suboptimal even for smaller degrees of parallelism. This commit implements tile-level load balancing (tlb) in which the IR loop is effectively subjugated in service of more equitably dividing work in the JR loop. This approach is especially potent for certain situations where the diagonal region of the MC x NR panel of C are significant relative to the entire region. However, it also seems to benefit many problem sizes of other level-3 operations (excluding trsm, which has an inherent algorithmic dependency in the IR loop that prevents the application of tlb). For now, tlb is implemented as _var2b.c macrokernels for gemm (which forms the basis for gemm, hemm, and symm), gemmt (which forms the basis of herk, her2k, syrk, and syr2k), and trmm (which forms the basis of trmm and trmm3). Which function pointers (_var2() or _var2b()) are embedded in the control tree will depend on whether the BLIS_ENABLE_JRIR_TLB cpp macro is defined, which is controlled by the value passed to the existing --thread-part-jrir=METHOD (or -r METHOD) configure option. This script adds 'tlb' as a valid option alongside the previously supported values of 'slab' and 'rr'. ('slab' is still the default.) Thanks to Leick Robinson for abstractly inspiring this work, and to Minh Quan Ho for inquiring (in PR #562, and before that in Issue #437) about the possibility of improved load balance in macrokernel loops, and even prototyping what it might look like, long before I fully understood the problem. - In bli_thread_range_weighted_sub(), tweaked the the way we compute the area of the current MC x NC trapezoidal panel of C by better taking into account the microtile structure along the diagonal. Previously, it was an underestimate, as it assumed MR = NR = 1 (that is, it assumed that the microtile column of C that overlapped with microtiles exactly coincided with the diagonal). Now, we only assume MR = NR. This is still a slight underestimate when MR != NR, so the additional area is scaled by 1.5 in a hackish attempt to compensate for this, as well as other additional effects that are difficult to model (such as the increased cost of writing to temporary tiles before finally updating C). The net effect of this better estimation of the trapezoidal area should be (on average) slightly larger regions assigned to threads that have little or no overlap with the diagonal region (and correspondingly slightly smaller regions in the diagonal region), which we expect will lead to slightly better load balancing in most situations. - Spun off the contents of bli_thread.[ch] that relate to computing thread ranges into one of three source/header file pairs: - bli_thread_range.[ch], which define functions that are not specific to the jr/ir loops; - bli_thread_range_slab_rr.[ch], which define functions that implement slab or round-robin partitioning for the jr/ir loops; - bli_thread_range_tlb.[ch], which define functions that implement tlb for the jr/ir loops. - Fixed the computation of a_next in the last iteration of the IR loop in bli_gemmt_l_ker_var2(). Previously, it always "wrapped" back around to the first micropanel of the current MC x KC packed block of A. However, this is almost never actually the micropanel that is used next. A new macro, bli_gemmt_l_wrap_a_upanel(), computes a_next correctly, with a similarly named bli_gemmt_u_wrap_a_upanel() for use in the upper-stored case (which *does* actually always choose the first micropanel of A as its a_next at the end of the IR loop). - Removed adjustments for a_next/b_next (a2/b2) for the diagonal- intersecting case of gemmt_l_ker_var2() and the above-diagonal case of gemmt_u_ker_var2() since these cases will only coincide with the last iteration of the IR loop in very small problems. - Defined bli_is_last_iter_l() and bli_is_last_iter_u(), the latter of which explicitly considers whether the current microtile is the last tile that intersects the diagonal. (The former does the same, but the computation coincides with the original bli_is_last_iter().) These functions are now used in gemmt to test when a_next (or a2) should "wrap" (as discussed above). Also defined bli_is_last_iter_tlb_l() and bli_is_last_iter_tlb_u(), which are similar to the aforementioned functions but are used when employing tlb in gemmt. - Redefined macros in bli_packm_thrinfo.h, which test whether an iteration of work is assigned to a thread, as static inline functions in bli_param_macro_defs.h (and then deleted bli_packm_thrinfo.h). In the process of redefining these macros, I also renamed them from bli_packm_my_iter_rr/sl() to bli_is_my_iter_rr/sl(). - Renamed bli_thread_range_jrir_rr() -> bli_thread_range_rr() bli_thread_range_jrir_sl() -> bli_thread_range_sl() bli_thread_range_jrir() -> bli_thread_range_slrr() - Renamed bli_is_last_iter() -> bli_is_last_iter_slrr() - Defined bli_info_get_thread_jrir_tlb() and renamed: - bli_info_get_thread_part_jrir_slab() -> bli_info_get_thread_jrir_slab() - bli_info_get_thread_part_jrir_rr() -> bli_info_get_thread_jrir_rr() - Modified bli_rntm_set_ways_for_op() to redirect IR loop parallelism into the JR loop when tlb is enabled for non-trsm level-3 operations. - Added a sanity check to prevent bli_prune_unref_mparts() from being used on packed objects. This prohibition is necessary because the current implementation does not take into account the atomicity of packed micropanel widths relative to the diagonal of structured matrices. That is, the function prunes greedily without regard to whether doing so would prune off part of a micropanel *which has already been packed* and assigned to a thread for inclusion in the computation. - Further restricted early returns in bli_prune_unref_mparts() to situations where the primary matrix is not only of general structure but also dense (in terms of its uplo_t value). The addition of the matrix's dense-ness to the conditional is required because gemmt is somewhat unusual in that its C matrix has general structure but is marked as lower- or upper-stored via its uplo_t. By only checking for general structure, attempts to prune gemmt C matrices would incorrectly result in early returns, even though that operation effectively treats the matrix as symmetric (and stored in only one triangle). - Fixed a latent bug in bli_thread_range_rr() wherein incorrect ranges were computed when 1 < bf. Thankfully, this bug was not yet manifesting since all current invocations used bf == 1. - Fixed a latent bug in some unexercised code in bli_?gemmt_l_ker_var2() that would perform incorrect pruning of unreferenced regions above where the diagonal of a lower-stored matrix intersects the right edge. Thankfully, the bug was not harming anything since those unreferenced regions were being pruned prior to the macrokernel. - Rewrote slab/rr-based gemmt macrokernels so that they no longer carved C into rectangular and diagonal regions prior to parallelizing each separately. The new macrokernels use a unified loop structure where quadratic (slab) partitioning is used. - Updated all level-3 macrokernels to have a more uniform coding style, such as wrt combining variable declarations with initializations as well as the use of const. - Updated bls_l3_packm_var[123].c to use bli_thrinfo_n_way() and bli_thrinfo_work_id() instead of bli_thrinfo_num_threads() and bli_thrinfo_thread_id(), respectively. This change probably should have been included in aeb5f0c. - Removed old prototypes in bli_gemmt_var.h and bli_trmm_var.h that corresponded to functions that were removed in aeb5f0c. - Other very minor cleanups. - Comment updates. - (cherry picked from commit 2e1ba9d13c23a06a7b6f8bd326af428f7ea68c31) --- build/bli_config.h.in | 4 + configure | 61 +- frame/1m/packm/bli_packm.h | 1 - frame/1m/packm/bli_packm_blk_var1.c | 16 +- frame/3/bli_l3_sup_packm_var.c | 16 +- frame/3/bli_l3_sup_var12.c | 4 +- frame/3/bli_l3_thrinfo.h | 12 +- frame/3/gemm/bli_gemm_cntl.c | 23 +- frame/3/gemm/bli_gemm_ker_var2.c | 93 +- frame/3/gemm/bli_gemm_ker_var2b.c | 379 ++++ frame/3/gemm/bli_gemm_var.h | 3 +- frame/3/gemmt/attic/bli_gemmt_l_ker_var2b.c | 429 +++++ frame/3/gemmt/attic/bli_gemmt_u_ker_var2b.c | 418 ++++ frame/3/gemmt/bli_gemmt_l_ker_var2.c | 274 +-- frame/3/gemmt/bli_gemmt_l_ker_var2b.c | 387 ++++ frame/3/gemmt/bli_gemmt_u_ker_var2.c | 273 +-- frame/3/gemmt/bli_gemmt_u_ker_var2b.c | 386 ++++ frame/3/gemmt/bli_gemmt_var.h | 45 +- frame/3/gemmt/bli_gemmt_x_ker_var2.c | 14 +- frame/3/gemmt/bli_gemmt_x_ker_var2b.c | 73 + .../3/gemmt/other/bli_gemmt_l_ker_var2.c.prev | 507 +++++ .../other/bli_gemmt_l_ker_var2b.c.before | 427 +++++ .../3/gemmt/other/bli_gemmt_u_ker_var2.c.prev | 510 +++++ .../other/bli_gemmt_u_ker_var2b.c.before | 415 ++++ frame/3/trmm/bli_trmm_ll_ker_var2.c | 99 +- frame/3/trmm/bli_trmm_ll_ker_var2b.c | 365 ++++ frame/3/trmm/bli_trmm_lu_ker_var2.c | 99 +- frame/3/trmm/bli_trmm_lu_ker_var2b.c | 366 ++++ frame/3/trmm/bli_trmm_rl_ker_var2.c | 75 +- frame/3/trmm/bli_trmm_rl_ker_var2b.c | 392 ++++ frame/3/trmm/bli_trmm_ru_ker_var2.c | 77 +- frame/3/trmm/bli_trmm_ru_ker_var2b.c | 390 ++++ frame/3/trmm/bli_trmm_var.h | 53 +- frame/3/trmm/bli_trmm_xx_ker_var2.c | 14 +- frame/3/trmm/bli_trmm_xx_ker_var2b.c | 87 + .../3/trmm/other/bli_trmm_rl_ker_var2.c.prev | 371 ++++ .../trmm/other/bli_trmm_rl_ker_var2.c.unified | 324 ++++ frame/3/trmm/other/bli_trmm_ru_ker_var2.c | 2 +- frame/3/trsm/bli_trsm_ll_ker_var2.c | 65 +- frame/3/trsm/bli_trsm_lu_ker_var2.c | 69 +- frame/3/trsm/bli_trsm_rl_ker_var2.c | 143 +- frame/3/trsm/bli_trsm_ru_ker_var2.c | 12 +- frame/3/trsm/bli_trsm_var.h | 2 +- frame/3/trsm/bli_trsm_xx_ker_var2.c | 14 +- frame/base/bli_info.c | 12 +- frame/base/bli_info.h | 5 +- frame/base/bli_prune.c | 39 +- frame/base/bli_rntm.c | 40 +- frame/include/bli_config_macro_defs.h | 10 + frame/include/bli_kernel_macro_defs.h | 2 + frame/include/bli_param_macro_defs.h | 51 +- frame/include/blis.h | 15 + frame/thread/bli_thread.c | 901 --------- frame/thread/bli_thread.h | 180 +- frame/thread/bli_thread_range.c | 1121 +++++++++++ frame/thread/bli_thread_range.h | 128 ++ frame/thread/bli_thread_range_slab_rr.c | 134 ++ frame/thread/bli_thread_range_slab_rr.h | 116 ++ frame/thread/bli_thread_range_tlb.c | 1699 +++++++++++++++++ frame/thread/bli_thread_range_tlb.h | 192 ++ frame/thread/old/bli_thread_range_snake.c | 120 ++ .../old/bli_thread_range_snake.h} | 46 +- sandbox/gemmlike/bls_gemm_bp_var1.c | 4 +- sandbox/gemmlike/bls_l3_packm_var1.c | 8 +- sandbox/gemmlike/bls_l3_packm_var2.c | 8 +- testsuite/src/test_libblis.c | 7 +- testsuite/src/test_trmm.c | 3 + 67 files changed, 10608 insertions(+), 2022 deletions(-) create mode 100644 frame/3/gemm/bli_gemm_ker_var2b.c create mode 100644 frame/3/gemmt/attic/bli_gemmt_l_ker_var2b.c create mode 100644 frame/3/gemmt/attic/bli_gemmt_u_ker_var2b.c create mode 100644 frame/3/gemmt/bli_gemmt_l_ker_var2b.c create mode 100644 frame/3/gemmt/bli_gemmt_u_ker_var2b.c create mode 100644 frame/3/gemmt/bli_gemmt_x_ker_var2b.c create mode 100644 frame/3/gemmt/other/bli_gemmt_l_ker_var2.c.prev create mode 100644 frame/3/gemmt/other/bli_gemmt_l_ker_var2b.c.before create mode 100644 frame/3/gemmt/other/bli_gemmt_u_ker_var2.c.prev create mode 100644 frame/3/gemmt/other/bli_gemmt_u_ker_var2b.c.before create mode 100644 frame/3/trmm/bli_trmm_ll_ker_var2b.c create mode 100644 frame/3/trmm/bli_trmm_lu_ker_var2b.c create mode 100644 frame/3/trmm/bli_trmm_rl_ker_var2b.c create mode 100644 frame/3/trmm/bli_trmm_ru_ker_var2b.c create mode 100644 frame/3/trmm/bli_trmm_xx_ker_var2b.c create mode 100644 frame/3/trmm/other/bli_trmm_rl_ker_var2.c.prev create mode 100644 frame/3/trmm/other/bli_trmm_rl_ker_var2.c.unified create mode 100644 frame/thread/bli_thread_range.c create mode 100644 frame/thread/bli_thread_range.h create mode 100644 frame/thread/bli_thread_range_slab_rr.c create mode 100644 frame/thread/bli_thread_range_slab_rr.h create mode 100644 frame/thread/bli_thread_range_tlb.c create mode 100644 frame/thread/bli_thread_range_tlb.h create mode 100644 frame/thread/old/bli_thread_range_snake.c rename frame/{1m/packm/bli_packm_thrinfo.h => thread/old/bli_thread_range_snake.h} (70%) diff --git a/build/bli_config.h.in b/build/bli_config.h.in index 41e76d2144..7dc67059f8 100644 --- a/build/bli_config.h.in +++ b/build/bli_config.h.in @@ -80,6 +80,10 @@ #define BLIS_ENABLE_JRIR_RR #endif +#if @enable_jrir_tlb@ +#define BLIS_ENABLE_JRIR_TLB +#endif + #if @enable_pba_pools@ #define BLIS_ENABLE_PBA_POOLS #else diff --git a/configure b/configure index 286a66123c..06201b4fa9 100755 --- a/configure +++ b/configure @@ -340,16 +340,36 @@ print_usage() echo " " echo " -r METHOD, --thread-part-jrir=METHOD" echo " " - echo " Request a method of assigning micropanels to threads in" - echo " the JR and IR loops. Valid values for METHOD are 'slab'" - echo " and 'rr'. Using 'slab' assigns (as much as possible)" - echo " contiguous regions of micropanels to each thread while" - echo " using 'rr' assigns micropanels to threads in a round-" - echo " robin fashion. The chosen method also applies during" - echo " the packing of A and B. The default method is 'slab'." - echo " NOTE: Specifying this option constitutes a request," - echo " which may be ignored in select situations if the" - echo " implementation has a good reason to do so." + echo " Select a strategy for partitioning computation in JR and" + echo " IR loops and assigning that computation to threads. Valid" + echo " values for METHOD are 'rr', 'slab', and 'tlb':" + echo " 'rr': Assign the computation associated with whole" + echo " columns of microtiles to threads in a round-" + echo " robin fashion. When selected, round-robin" + echo " assignment is also employed during packing." + echo " 'slab': Partition the computation into N contiguous" + echo " regions, where each region contains a whole" + echo " number of microtile columns, and assign one" + echo " region to each thread. For some operations, the" + echo " number of microtile columns contained within a" + echo " given region may differ from that of other" + echo " regions, depending on how much work is implied" + echo " by each region. When selected, slab assignment" + echo " is also employed during packing." + echo " 'tlb': Tile-level load balancing is similar to slab," + echo " except that regions will be divided at a more" + echo " granular level (individual microtiles instead" + echo " of whole columns of microtiles) to ensure more" + echo " equitable assignment of work to threads. When" + echo " selected, tlb will only be employed for level-3" + echo " operations except trsm; due to practical and" + echo " algorithmic limitations, slab partitioning will" + echo " be used instead during packing and for trsm." + echo " The default strategy is 'slab'. NOTE: Specifying this" + echo " option constitutes a request, which may be ignored in" + echo " select situations if implementation has a good reason to" + echo " do so. (See description of 'tlb' above for an example of" + echo " this.)" echo " " echo " --disable-trsm-preinversion, --enable-trsm-preinversion" echo " " @@ -3731,16 +3751,20 @@ main() # Check the method of assigning micropanels to threads in the JR and IR # loops. - enable_jrir_slab_01=0 enable_jrir_rr_01=0 - if [ "x${thread_part_jrir}" = "xslab" ]; then - echo "${script_name}: requesting slab threading in jr and ir loops." - enable_jrir_slab_01=1 - elif [ "x${thread_part_jrir}" = "xrr" ]; then - echo "${script_name}: requesting round-robin threading in jr and ir loops." + enable_jrir_slab_01=0 + enable_jrir_tlb_01=0 + if [ "x${thread_part_jrir}" = "xrr" ]; then + echo "${script_name}: requesting round-robin (rr) work partitioning in jr and/or ir loops." enable_jrir_rr_01=1 + elif [ "x${thread_part_jrir}" = "xslab" ]; then + echo "${script_name}: requesting slab work partitioning in jr and/or ir loops." + enable_jrir_slab_01=1 + elif [ "x${thread_part_jrir}" = "xtlb" ]; then + echo "${script_name}: requesting tile-level load balancing (tlb) in unified jr+ir loop." + enable_jrir_tlb_01=1 else - echo "${script_name}: *** Unsupported method of thread partitioning in jr and ir loops: ${thread_part_jrir}." + echo "${script_name}: *** Unsupported method of work partitioning in jr/ir loops: ${thread_part_jrir}." exit 1 fi @@ -4177,8 +4201,9 @@ main() | sed -e "s/@enable_pthreads_as_def@/${enable_pthreads_as_def_01}/g" \ | sed -e "s/@enable_hpx@/${enable_hpx_01}/g" \ | sed -e "s/@enable_hpx_as_def@/${enable_hpx_as_def_01}/g" \ - | sed -e "s/@enable_jrir_slab@/${enable_jrir_slab_01}/g" \ | sed -e "s/@enable_jrir_rr@/${enable_jrir_rr_01}/g" \ + | sed -e "s/@enable_jrir_slab@/${enable_jrir_slab_01}/g" \ + | sed -e "s/@enable_jrir_tlb@/${enable_jrir_tlb_01}/g" \ | sed -e "s/@enable_pba_pools@/${enable_pba_pools_01}/g" \ | sed -e "s/@enable_sba_pools@/${enable_sba_pools_01}/g" \ | sed -e "s/@enable_mem_tracing@/${enable_mem_tracing_01}/g" \ diff --git a/frame/1m/packm/bli_packm.h b/frame/1m/packm/bli_packm.h index 80878fba01..7d73bf903e 100644 --- a/frame/1m/packm/bli_packm.h +++ b/frame/1m/packm/bli_packm.h @@ -39,7 +39,6 @@ #include "bli_packm_init.h" #include "bli_packm_int.h" #include "bli_packm_scalar.h" -#include "bli_packm_thrinfo.h" #include "bli_packm_part.h" diff --git a/frame/1m/packm/bli_packm_blk_var1.c b/frame/1m/packm/bli_packm_blk_var1.c index b8f4f945d9..561988e7f7 100644 --- a/frame/1m/packm/bli_packm_blk_var1.c +++ b/frame/1m/packm/bli_packm_blk_var1.c @@ -170,11 +170,11 @@ void bli_packm_blk_var1 const dim_t tid = bli_thrinfo_work_id( thread ); // Determine the thread range and increment using the current thread's - // packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + // packm thrinfo_t node. NOTE: The definition of bli_thread_range_slrr() // will depend on whether slab or round-robin partitioning was requested // at configure-time. dim_t it_start, it_end, it_inc; - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); char* p_begin = p_cast; @@ -195,10 +195,10 @@ void bli_packm_blk_var1 char* c_begin = c_cast + (ic )*incc*dt_c_size; - // Hermitian/symmetric and general packing may use slab or - // round-robin (bli_packm_my_iter()), depending on which was - // selected at configure-time. - if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) + // Hermitian/symmetric and general packing may use slab or round- + // robin (bli_is_my_iter()), depending on which was selected at + // configure-time. + if ( bli_is_my_iter( it, it_start, it_end, tid, nt ) ) { packm_ker_cast( bli_is_triangular( strucc ) ? BLIS_GENERAL : strucc, diagc, @@ -286,9 +286,9 @@ void bli_packm_blk_var1 // We nudge the imaginary stride up by one if it is odd. is_p_use += ( bli_is_odd( is_p_use ) ? 1 : 0 ); - // NOTE: We MUST use round-robin work allocation (bli_packm_my_iter_rr()) + // NOTE: We MUST use round-robin work allocation (bli_is_my_iter_rr()) // when packing micropanels of a triangular matrix. - if ( bli_packm_my_iter_rr( it, it_start, it_end, tid, nt ) ) + if ( bli_is_my_iter_rr( it, tid, nt ) ) { packm_ker_cast( strucc, diagc, diff --git a/frame/3/bli_l3_sup_packm_var.c b/frame/3/bli_l3_sup_packm_var.c index e47f65aeaf..67b33f407c 100644 --- a/frame/3/bli_l3_sup_packm_var.c +++ b/frame/3/bli_l3_sup_packm_var.c @@ -155,10 +155,10 @@ void PASTEMAC(ch,varname) \ dim_t it_start, it_end, it_inc; \ \ /* Determine the thread range and increment using the current thread's - packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + packm thrinfo_t node. NOTE: The definition of bli_thread_range_slrr() will depend on whether slab or round-robin partitioning was requested at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ \ /* Iterate over every logical micropanel in the source matrix. */ \ for ( ic = ic0, it = 0; it < n_iter; \ @@ -175,9 +175,9 @@ void PASTEMAC(ch,varname) \ panel_len_i = panel_len_full; \ panel_len_max_i = panel_len_max; \ \ - /* The definition of bli_packm_my_iter() will depend on whether slab + /* The definition of bli_is_my_iter() will depend on whether slab or round-robin partitioning was requested at configure-time. */ \ - if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + if ( bli_is_my_iter( it, it_start, it_end, tid, nt ) ) \ { \ f \ ( \ @@ -398,10 +398,10 @@ void PASTEMAC(ch,varname) \ dim_t it_start, it_end, it_inc; \ \ /* Determine the thread range and increment using the current thread's - packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + packm thrinfo_t node. NOTE: The definition of bli_thread_range_slrr() will depend on whether slab or round-robin partitioning was requested at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ \ /* Iterate over every logical micropanel in the source matrix. */ \ for ( it = 0; it < n_iter; it += 1 ) \ @@ -412,9 +412,9 @@ void PASTEMAC(ch,varname) \ ctype* p_use = p_begin; \ \ { \ - /* The definition of bli_packm_my_iter() will depend on whether slab + /* The definition of bli_is_my_iter() will depend on whether slab or round-robin partitioning was requested at configure-time. */ \ - if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + if ( bli_is_my_iter( it, it_start, it_end, tid, nt ) ) \ { \ PASTEMAC2(ch,scal2v,BLIS_TAPI_EX_SUF) \ ( \ diff --git a/frame/3/bli_l3_sup_var12.c b/frame/3/bli_l3_sup_var12.c index d65482243b..4162c3d33e 100644 --- a/frame/3/bli_l3_sup_var12.c +++ b/frame/3/bli_l3_sup_var12.c @@ -357,11 +357,11 @@ void PASTEMAC(ch,varname) \ object. */ \ /* ctype* a2 = bli_gemm_get_next_a_upanel( a_ir, irstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, ir_iter, 0, 1 ) ) \ + if ( bli_is_last_iter_slrr( i, ir_iter, 0, 1 ) ) \ { \ a2 = a_00; \ b2 = bli_gemm_get_next_b_upanel( b_jr, jrstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, jr_iter, 0, 1 ) ) \ + if ( bli_is_last_iter_slrr( j, jr_iter, 0, 1 ) ) \ b2 = b_00; \ } \ \ diff --git a/frame/3/bli_l3_thrinfo.h b/frame/3/bli_l3_thrinfo.h index b1290df508..2ea7a3fc23 100644 --- a/frame/3/bli_l3_thrinfo.h +++ b/frame/3/bli_l3_thrinfo.h @@ -39,22 +39,22 @@ // gemm -// NOTE: The definition of bli_gemm_get_next_?_upanel() does not need to -// change depending on BLIS_ENABLE_JRIR_SLAB / BLIS_ENABLE_JRIR_RR. #define bli_gemm_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc ) #define bli_gemm_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc ) // gemmt -// NOTE: The definition of bli_gemmt_get_next_?_upanel() does not need to -// change depending on BLIS_ENABLE_JRIR_SLAB / BLIS_ENABLE_JRIR_RR. #define bli_gemmt_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc ) #define bli_gemmt_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc ) +// NOTE: Here, we assume NO parallelism in the IR loop. +#define bli_gemmt_l_wrap_a_upanel( a0, step, doff_j, mr, nr ) \ + ( a0 + ( (-doff_j + 1*nr) / mr ) * step ) +#define bli_gemmt_u_wrap_a_upanel( a0, step, doff_j, mr, nr ) \ + ( a0 ) + // trmm -// NOTE: The definition of bli_trmm_get_next_?_upanel() does not need to -// change depending on BLIS_ENABLE_JRIR_SLAB / BLIS_ENABLE_JRIR_RR. #define bli_trmm_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc ) #define bli_trmm_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc ) diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index bd8d97d13d..b9c231cf72 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -61,10 +61,25 @@ cntl_t* bli_gemmbp_cntl_create void_fp macro_kernel_fp; // Choose the default macrokernel based on the operation family... - if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2; - else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2; - else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2; - else /* should never execute */ macro_kernel_fp = NULL; + if ( family == BLIS_GEMM ) macro_kernel_fp = + #ifdef BLIS_ENABLE_JRIR_TLB + bli_gemm_ker_var2b; + #else // ifdef ( _SLAB || _RR ) + bli_gemm_ker_var2; + #endif + else if ( family == BLIS_GEMMT ) macro_kernel_fp = + #ifdef BLIS_ENABLE_JRIR_TLB + bli_gemmt_x_ker_var2b; + #else // ifdef ( _SLAB || _RR ) + bli_gemmt_x_ker_var2; + #endif + else if ( family == BLIS_TRMM ) macro_kernel_fp = + #ifdef BLIS_ENABLE_JRIR_TLB + bli_trmm_xx_ker_var2b; + #else // ifdef ( _SLAB || _RR ) + bli_trmm_xx_ker_var2; + #endif + else /* should never execute */ macro_kernel_fp = NULL; // ...unless a non-NULL kernel function pointer is passed in, in which // case we use that instead. diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index d596950819..3e862e6c59 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -47,7 +47,7 @@ typedef void (*xpbys_mxn_vft) #undef GENTFUNC2 #define GENTFUNC2(ctypex,ctypey,chx,chy,op) \ \ -void PASTEMAC2(chx,chy,op) \ +BLIS_INLINE void PASTEMAC2(chx,chy,op) \ ( \ dim_t m, \ dim_t n, \ @@ -77,31 +77,31 @@ static xpbys_mxn_vft GENARRAY2_ALL(xpbys_mxn, xpbys_mxn_fn); void bli_gemm_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { num_t dt_exec = bli_obj_exec_dt( c ); num_t dt_c = bli_obj_dt( c ); - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); dim_t m = bli_obj_length( c ); dim_t n = bli_obj_width( c ); dim_t k = bli_obj_width( a ); const char* a_cast = bli_obj_buffer_at_off( a ); - inc_t is_a = bli_obj_imag_stride( a ); + const inc_t is_a = bli_obj_imag_stride( a ); dim_t pd_a = bli_obj_panel_dim( a ); inc_t ps_a = bli_obj_panel_stride( a ); const char* b_cast = bli_obj_buffer_at_off( b ); - inc_t is_b = bli_obj_imag_stride( b ); + const inc_t is_b = bli_obj_imag_stride( b ); dim_t pd_b = bli_obj_panel_dim( b ); inc_t ps_b = bli_obj_panel_stride( b ); @@ -116,8 +116,7 @@ void bli_gemm_ker_var2 // NOTE: We know that the internal scalars of A and B are already of the // target datatypes because the necessary typecasting would have already // taken place during bli_packm_init(). - obj_t scalar_a; - obj_t scalar_b; + obj_t scalar_a, scalar_b; bli_obj_scalar_detach( a, &scalar_a ); bli_obj_scalar_detach( b, &scalar_b ); bli_mulsc( &scalar_a, &scalar_b ); @@ -217,22 +216,19 @@ void bli_gemm_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_c_size; - inc_t cstep_c = cs_c * NR * dt_c_size; + const inc_t rstep_c = rs_c * MR * dt_c_size; + const inc_t cstep_c = cs_c * NR * dt_c_size; auxinfo_t aux; @@ -255,20 +251,19 @@ void bli_gemm_ker_var2 thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); // Query the number of threads and thread ids for each loop. - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); - dim_t ir_nt = bli_thrinfo_n_way( caucus ); - dim_t ir_tid = bli_thrinfo_work_id( caucus ); + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + const dim_t ir_tid = bli_thrinfo_work_id( caucus ); - dim_t jr_start, jr_end; - dim_t ir_start, ir_end; - dim_t jr_inc, ir_inc; + dim_t jr_start, jr_end, jr_inc; + dim_t ir_start, ir_end, ir_inc; // Determine the thread range and increment for the 2nd and 1st loops. - // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // NOTE: The definition of bli_thread_range_slrr() will depend on whether // slab or round-robin partitioning was requested at configure-time. - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_slrr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) @@ -276,7 +271,9 @@ void bli_gemm_ker_var2 const char* b1 = b_cast + j * cstep_b; char* c1 = c_cast + j * cstep_c; - const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + // Compute the current microtile's width. + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -287,15 +284,17 @@ void bli_gemm_ker_var2 const char* a1 = a_cast + i * rstep_a; char* c11 = c1 + i * rstep_c; - const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + // Compute the current microtile's length. + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); // Compute the addresses of the next panels of A and B. const char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); - if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) + if ( bli_is_last_iter_slrr( i, ir_end, ir_tid, ir_nt ) ) { a2 = a_cast; b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); - if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) + if ( bli_is_last_iter_slrr( j, jr_end, jr_tid, jr_nt ) ) b2 = b_cast; } @@ -342,22 +341,20 @@ void bli_gemm_ker_var2 ( cntx_t* )cntx ); - // Accumulate to C with type-casting. + // Accumulate to C with typecasting. xpbys_mxn[ dt_exec ][ dt_c ] ( - m_cur, n_cur, - &ct, rs_ct, cs_ct, - ( void* )beta_cast, - c11, rs_c, cs_c + m_cur, n_cur, + &ct, rs_ct, cs_ct, + ( void* )beta_cast, + c11, rs_c, cs_c ); } } } - -/* -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); -*/ } +//PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); +//PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); + diff --git a/frame/3/gemm/bli_gemm_ker_var2b.c b/frame/3/gemm/bli_gemm_ker_var2b.c new file mode 100644 index 0000000000..50375708af --- /dev/null +++ b/frame/3/gemm/bli_gemm_ker_var2b.c @@ -0,0 +1,379 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +typedef void (*xpbys_mxn_vft) + ( + dim_t m, + dim_t n, + void* x, inc_t rs_x, inc_t cs_x, + void* b, + void* y, inc_t rs_y, inc_t cs_y + ); + +#undef GENTFUNC2 +#define GENTFUNC2(ctypex,ctypey,chx,chy,op) \ +\ +BLIS_INLINE void PASTEMAC2(chx,chy,op) \ + ( \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t rs_y, inc_t cs_y \ + ) \ +{ \ + ctypex* restrict x_cast = x; \ + ctypey* restrict b_cast = b; \ + ctypey* restrict y_cast = y; \ +\ + PASTEMAC3(chx,chy,chy,xpbys_mxn) \ + ( \ + m, n, \ + x_cast, rs_x, cs_x, \ + b_cast, \ + y_cast, rs_y, cs_y \ + ); \ +} + +INSERT_GENTFUNC2_BASIC0(xpbys_mxnb_fn); +INSERT_GENTFUNC2_MIXDP0(xpbys_mxnb_fn); + +static xpbys_mxn_vft GENARRAY2_ALL(xpbys_mxn, xpbys_mxnb_fn); + + +void bli_gemm_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + num_t dt_exec = bli_obj_exec_dt( c ); + num_t dt_c = bli_obj_dt( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const char* a_cast = bli_obj_buffer_at_off( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + dim_t pd_a = bli_obj_panel_dim( a ); + inc_t ps_a = bli_obj_panel_stride( a ); + + const char* b_cast = bli_obj_buffer_at_off( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + dim_t pd_b = bli_obj_panel_dim( b ); + inc_t ps_b = bli_obj_panel_stride( b ); + + char* c_cast = bli_obj_buffer_at_off( c ); + inc_t rs_c = bli_obj_row_stride( c ); + inc_t cs_c = bli_obj_col_stride( c ); + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Detach and multiply the scalars attached to A and B. + // NOTE: We know that the internal scalars of A and B are already of the + // target datatypes because the necessary typecasting would have already + // taken place during bli_packm_init(). + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + // NOTE: We know that scalar_b is of type dt_exec due to the above code + // that casts the scalars of A and B to dt_exec via scalar_a and scalar_b, + // and we know that the internal scalar in C is already of the type dt_c + // due to the casting in the implementation of bli_obj_scalar_attach(). + const char* alpha_cast = bli_obj_internal_scalar_buffer( &scalar_b ); + const char* beta_cast = bli_obj_internal_scalar_buffer( c ); + + // If 1m is being employed on a column- or row-stored matrix with a + // real-valued beta, we can use the real domain macro-kernel, which + // eliminates a little overhead associated with the 1m virtual + // micro-kernel. + // Only employ this optimization if the storage datatype of C is + // equal to the execution/computation datatype. +#if 1 + if ( bli_cntx_method( cntx ) == BLIS_1M ) + { + bli_gemm_ind_recast_1m_params + ( + &dt_exec, + &dt_c, + schema_a, + c, + &m, &n, &k, + &pd_a, &ps_a, + &pd_b, &ps_b, + &rs_c, &cs_c, + cntx + ); + } +#endif + +#ifdef BLIS_ENABLE_GEMM_MD + // Tweak parameters in select mixed domain cases (rcc, crc, ccr). + if ( bli_cntx_method( cntx ) == BLIS_NAT ) + { + bli_gemm_md_ker_var2_recast + ( + &dt_exec, + bli_obj_dt( a ), + bli_obj_dt( b ), + &dt_c, + &m, &n, &k, + &pd_a, &ps_a, + &pd_b, &ps_b, + c, + &rs_c, &cs_c + ); + } +#endif + + const siz_t dt_size = bli_dt_size( dt_exec ); + const siz_t dt_c_size = bli_dt_size( dt_c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + + // Query the params field from the obj_t. If it is non-NULL, grab the ukr + // field of the params struct. If that function pointer is non-NULL, use it + // as our microkernel instead of the default microkernel queried from the + // cntx above. + const gemm_ker_params_t* params = bli_obj_ker_params( c ); + gemm_ukr_vft user_ukr = params ? params->ukr : NULL; + if ( user_ukr ) gemm_ukr = user_ukr; + + // Temporary C buffer for edge cases. Note that the strides of this + // temporary buffer are set so that they match the storage of the + // original C matrix. For example, if C is column-stored, ct will be + // column-stored as well. + char ct[ BLIS_STACK_BUF_MAX_SIZE ] + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_VIR_UKR, cntx ); + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + const char* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); + + // + // Assumptions/assertions: + // rs_a == 1 + // cs_a == PACKMR + // pd_a == MR + // ps_a == stride to next micro-panel of A + // rs_b == PACKNR + // cs_b == 1 + // pd_b == NR + // ps_b == stride to next micro-panel of B + // rs_c == (no assumptions) + // cs_c == (no assumptions) + // + + // Compute number of primary and leftover components of the m and n + // dimensions. + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; + + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; + + // Determine some increments used to step through A, B, and C. + const inc_t rstep_a = ps_a * dt_size; + + const inc_t cstep_b = ps_b * dt_size; + + const inc_t rstep_c = rs_c * MR * dt_c_size; + const inc_t cstep_c = cs_c * NR * dt_c_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // Save the imaginary stride of A and B to the auxinfo_t object. + bli_auxinfo_set_is_a( is_a, &aux ); + bli_auxinfo_set_is_b( is_b, &aux ); + + // Save the virtual microkernel address and the params. + bli_auxinfo_set_ukr( gemm_ukr, &aux ); + bli_auxinfo_set_params( params, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Notice that this variant doesn't utilize + // parallelism in the 1st (ir) loop around the microkernel. + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); + + // Determine the starting microtile offsets and number of microtiles to + // compute for each thread. Note that assignment of microtiles is done + // according to the tlb policy. + dim_t jr_st, ir_st; + const dim_t n_ut_for_me + = + bli_thread_range_tlb_d( jr_nt, jr_tid, m_iter, n_iter, MR, NR, &jr_st, &ir_st ); + + // It's possible that there are so few microtiles relative to the number + // of threads that one or more threads gets no work. If that happens, those + // threads can return early. + if ( n_ut_for_me == 0 ) return; + + // Start the jr/ir loops with the current thread's microtile offsets computed + // by bli_thread_range_tlb(). + dim_t i = ir_st; + dim_t j = jr_st; + + // Initialize a counter to track the number of microtiles computed by the + // current thread. + dim_t ut = 0; + + // Loop over the n dimension (NR columns at a time). + for ( ; true; ++j ) + { + const char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + // Compute the current microtile's width. + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + bli_auxinfo_set_next_b( b2, &aux ); + + // Loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + // Compute the current microtile's length. + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Edge case handling now occurs within the microkernel itself, but + // we must still explicitly accumulate to a temporary microtile in + // situations where a virtual microkernel is being used, such as + // during the 1m method or some cases of mixed datatypes. + if ( dt_exec == dt_c ) + { + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + } + else + { + // Invoke the gemm micro-kernel. + gemm_ukr + ( + MR, + NR, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )zero, + &ct, rs_ct, cs_ct, + &aux, + ( cntx_t* )cntx + ); + + // Accumulate to C with typecasting. + xpbys_mxn[ dt_exec ][ dt_c ] + ( + m_cur, n_cur, + &ct, rs_ct, cs_ct, + ( void* )beta_cast, + c11, rs_c, cs_c + ); + } + + ut += 1; + if ( ut == n_ut_for_me ) return; + } + + i = 0; + } +} + +//PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2b: b1", k, NR, b1, NR, 1, "%4.1f", "" ); +//PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2b: a1", MR, k, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2b: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); diff --git a/frame/3/gemm/bli_gemm_var.h b/frame/3/gemm/bli_gemm_var.h index 24f7ecfb9e..f69327db0c 100644 --- a/frame/3/gemm/bli_gemm_var.h +++ b/frame/3/gemm/bli_gemm_var.h @@ -65,6 +65,7 @@ GENPROT( gemm_blk_var1 ) GENPROT( gemm_blk_var2 ) GENPROT( gemm_blk_var3 ) -GENPROT( gemm_ker_var1 ) GENPROT( gemm_ker_var2 ) +GENPROT( gemm_ker_var2b ) + diff --git a/frame/3/gemmt/attic/bli_gemmt_l_ker_var2b.c b/frame/3/gemmt/attic/bli_gemmt_l_ker_var2b.c new file mode 100644 index 0000000000..fbfafebb0e --- /dev/null +++ b/frame/3/gemmt/attic/bli_gemmt_l_ker_var2b.c @@ -0,0 +1,429 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +typedef void (*xpbys_mxn_l_vft) + ( + doff_t diagoff, + dim_t m, + dim_t n, + void* x, inc_t rs_x, inc_t cs_x, + void* b, + void* y, inc_t rs_y, inc_t cs_y + ); + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +\ +void PASTEMAC(ch,op) \ + ( \ + doff_t diagoff, \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t rs_y, inc_t cs_y \ + ) \ +{ \ + ctype* restrict x_cast = x; \ + ctype* restrict b_cast = b; \ + ctype* restrict y_cast = y; \ +\ + PASTEMAC3(ch,ch,ch,xpbys_mxn_l) \ + ( \ + diagoff, \ + m, n, \ + x_cast, rs_x, cs_x, \ + b_cast, \ + y_cast, rs_y, cs_y \ + ); \ +} + +INSERT_GENTFUNC_BASIC0(xpbys_mxn_l_fn); + +static xpbys_mxn_l_vft GENARRAY(xpbys_mxn_l, xpbys_mxn_l_fn); + +void bli_gemmt_l_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt = bli_obj_exec_dt( c ); + const dim_t dt_size = bli_dt_size( dt ); + + doff_t diagoffc = bli_obj_diag_offset( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Index into the type combination array to extract the correct + // function pointer. + ftypes[dt_exec] + ( + diagoffc, + schema_a, + schema_b, + m, + n, + k, + ( void* )buf_alpha, + ( void* )buf_a, cs_a, is_a, + pd_a, ps_a, + ( void* )buf_b, rs_b, is_b, + pd_b, ps_b, + ( void* )buf_beta, + buf_c, rs_c, cs_c, + ( cntx_t* )cntx, + rntm, + thread + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffc, \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt, BLIS_GEMM_VIR_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Safeguard: If the current panel of C is entirely above the diagonal, + it is not stored. So we do nothing. */ \ + if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; \ +\ + /* If there is a zero region above where the diagonal of C intersects + the left edge of the panel, adjust the pointer to C and A and treat + this case as if the diagonal offset were zero. + NOTE: It's possible that after this pruning that the diagonal offset + is still negative (though its absolute value is guaranteed to be less + than MR). */ \ + if ( diagoffc < 0 ) \ + { \ + const dim_t ip = -diagoffc / MR; \ + const dim_t i = ip * MR; \ +\ + m = m - i; \ + diagoffc = diagoffc % MR; \ + c_cast = c_cast + (i )*rs_c; \ + a_cast = a_cast + (ip )*ps_a; \ + } \ +\ + /* If there is a zero region to the right of where the diagonal + of C intersects the bottom of the panel, shrink it to prevent + "no-op" iterations from executing. */ \ + if ( diagoffc + m < n ) \ + { \ + n = diagoffc + m; \ + } \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); \ + const dim_t n_left = n % NR; \ +\ + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); \ + const dim_t m_left = m % MR; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + const inc_t rstep_a = ps_a; \ +\ + const inc_t cstep_b = ps_b; \ +\ + const inc_t rstep_c = rs_c * MR; \ + const inc_t cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + const dim_t jr_inc = 1; \ + const dim_t ir_inc = 1; \ +\ + /* Determine the starting microtile offsets and number of microtiles to + compute for each thread. Note that assignment of microtiles is done + according to the tlb policy. */ \ + dim_t jr_st, ir_st; \ + const dim_t n_ut_for_me \ + = \ + bli_thread_range_tlb( thread, diagoffc, BLIS_LOWER, m, n, MR, NR, \ + &jr_st, &ir_st ); \ +\ + /* It's possible that there are so few microtiles relative to the number + of threads that one or more threads gets no work. If that happens, those + threads can return early. */ \ + if ( n_ut_for_me == 0 ) return; \ +\ + /* Start the jr/ir loops with the current thread's microtile offsets computed + by bli_thread_range_tlb(). */ \ + dim_t i = ir_st; \ + dim_t j = jr_st; \ +\ + /* Initialize a counter to track the number of microtiles computed by the + current thread. */ \ + dim_t ut = 0; \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( ; true; ++j ) \ + { \ + ctype* restrict b1 = b_cast + j * cstep_b; \ + ctype* restrict c1 = c_cast + j * cstep_c; \ +\ + /* Compute the diagonal offset for the column of microtiles at (0,j). */ \ + const doff_t diagoffc_j = diagoffc - (doff_t)j*NR; \ + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) \ + ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + ctype* restrict b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( ; i < m_iter; ++i ) \ + { \ + /* Compute the diagonal offset for the microtile at (i,j). */ \ + const doff_t diagoffc_ij = diagoffc_j + (doff_t)i*MR; \ + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) \ + ? MR : m_left ); \ +\ + /* If the diagonal intersects the current MR x NR microtile, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the microtile is strictly below the diagonal, + we compute and store as we normally would. + And if we're strictly above the diagonal, we simply advance + to last microtile before the diagonal. */ \ + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + ctype* restrict a1 = a_cast + i * rstep_a; \ + ctype* restrict c11 = c1 + i * rstep_c; \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + ctype* restrict a2 \ + = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + MR, \ + NR, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale C and add the result to only the stored part. */ \ + PASTEMAC(ch,xpbys_mxn_l)( diagoffc_ij, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ +\ + ut += 1; \ + if ( ut == n_ut_for_me ) return; \ + } \ + else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + ctype* restrict a1 = a_cast + i * rstep_a; \ + ctype* restrict c11 = c1 + i * rstep_c; \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + ctype* restrict a2 \ + = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter_tlb_l( i, m_iter ) ) \ + { \ + a2 = bli_gemmt_l_wrap_a_upanel( a_cast, rstep_a, \ + diagoffc_j, MR, NR ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + /* We don't bother computing b2 for the last iteration of the + jr loop since the current thread won't know its j_st until + the next time it calls bli_thread_range_tlb(). */ \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ +\ + ut += 1; \ + if ( ut == n_ut_for_me ) return; \ + } \ + else /* if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) */ \ + { \ + /* Skip ahead to the last microtile strictly above the diagonal. */ \ + i = -diagoffc_j / MR - 1; \ + } \ + } \ +\ + /* Upon reaching the end of the column of microtiles, get ready to begin at + the beginning of the next column (i.e., the next jr loop iteration). */ \ + i = 0; \ + } \ +} + +INSERT_GENTFUNC_BASIC0( gemmt_l_ker_var2b ) + diff --git a/frame/3/gemmt/attic/bli_gemmt_u_ker_var2b.c b/frame/3/gemmt/attic/bli_gemmt_u_ker_var2b.c new file mode 100644 index 0000000000..311180d192 --- /dev/null +++ b/frame/3/gemmt/attic/bli_gemmt_u_ker_var2b.c @@ -0,0 +1,418 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmt_fp + +typedef void (*FUNCPTR_T) + ( + doff_t diagoffc, + pack_t schema_a, + pack_t schema_b, + dim_t m, + dim_t n, + dim_t k, + void* alpha, + void* a, inc_t cs_a, inc_t is_a, + dim_t pd_a, inc_t ps_a, + void* b, inc_t rs_b, inc_t is_b, + dim_t pd_b, inc_t ps_b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +static FUNCPTR_T GENARRAY(ftypes,gemmt_u_ker_var2b); + + +void bli_gemmt_u_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + const num_t dt_exec = bli_obj_exec_dt( c ); + + const doff_t diagoffc = bli_obj_diag_offset( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Index into the type combination array to extract the correct + // function pointer. + ftypes[dt_exec] + ( + diagoffc, + schema_a, + schema_b, + m, + n, + k, + ( void* )buf_alpha, + ( void* )buf_a, cs_a, is_a, + pd_a, ps_a, + ( void* )buf_b, rs_b, is_b, + pd_b, ps_b, + ( void* )buf_beta, + buf_c, rs_c, cs_c, + ( cntx_t* )cntx, + rntm, + thread + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffc, \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt, BLIS_GEMM_VIR_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Safeguard: If the current panel of C is entirely below the diagonal, + it is not stored. So we do nothing. */ \ + if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \ +\ + /* If there is a zero region to the left of where the diagonal of C + intersects the top edge of the panel, adjust the pointer to C and B + and treat this case as if the diagonal offset were zero. + NOTE: It's possible that after this pruning that the diagonal offset + is still positive (though it is guaranteed to be less than NR). */ \ + if ( diagoffc > 0 ) \ + { \ + const dim_t jp = diagoffc / NR; \ + const dim_t j = jp * NR; \ +\ + n = n - j; \ + diagoffc = diagoffc % NR; \ + c_cast = c_cast + (j )*cs_c; \ + b_cast = b_cast + (jp )*ps_b; \ + } \ +\ + /* If there is a zero region below where the diagonal of C intersects + the right edge of the panel, shrink it to prevent "no-op" iterations + from executing. */ \ + if ( -diagoffc + n < m ) \ + { \ + m = -diagoffc + n; \ + } \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); \ + const dim_t n_left = n % NR; \ +\ + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); \ + const dim_t m_left = m % MR; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + const inc_t rstep_a = ps_a; \ +\ + const inc_t cstep_b = ps_b; \ +\ + const inc_t rstep_c = rs_c * MR; \ + const inc_t cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the virtual microkernel address and the params. */ \ + /*bli_auxinfo_set_ukr( gemm_ukr, &aux );*/ \ + /*bli_auxinfo_set_params( params, &aux );*/ \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ +\ + const dim_t jr_inc = 1; \ + const dim_t ir_inc = 1; \ +\ + /* Determine the starting microtile offsets and number of microtiles to + compute for each thread. Note that assignment of microtiles is done + according to the tlb policy. */ \ + dim_t jr_st, ir_st; \ + const dim_t n_ut_for_me \ + = \ + bli_thread_range_tlb( thread, diagoffc, BLIS_UPPER, m, n, MR, NR, \ + &jr_st, &ir_st ); \ +\ + /* It's possible that there are so few microtiles relative to the number + of threads that one or more threads gets no work. If that happens, those + threads can return early. */ \ + if ( n_ut_for_me == 0 ) return; \ +\ + /* Start the jr/ir loops with the current thread's microtile offsets computed + by bli_thread_range_tlb(). */ \ + dim_t i = ir_st; \ + dim_t j = jr_st; \ +\ + /* Initialize a counter to track the number of microtiles computed by the + current thread. */ \ + dim_t ut = 0; \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( ; true; ++j ) \ + { \ + ctype* restrict b1 = b_cast + j * cstep_b; \ + ctype* restrict c1 = c_cast + j * cstep_c; \ +\ + /* Compute the diagonal offset for the column of microtiles at (0,j). */ \ + const doff_t diagoffc_j = diagoffc - (doff_t)j*NR; \ + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) \ + ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + ctype* restrict b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( ; i < m_iter; ++i ) \ + { \ + /* Compute the diagonal offset for the microtile at (i,j). */ \ + const doff_t diagoffc_ij = diagoffc_j + (doff_t)i*MR; \ + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) \ + ? MR : m_left ); \ +\ + /* If the diagonal intersects the current MR x NR submatrix, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the submatrix is strictly above the diagonal, + we compute and store as we normally would. + And if we're strictly above the diagonal, we simply advance + to last microtile before the bottom of the matrix. */ \ + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + ctype* restrict a1 = a_cast + i * rstep_a; \ + ctype* restrict c11 = c1 + i * rstep_c; \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + ctype* restrict a2 \ + = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter_tlb_u( diagoffc_ij, MR, NR ) ) \ + { \ + a2 = bli_gemmt_u_wrap_a_upanel( a_cast, rstep_a, \ + diagoffc_j, MR, NR ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + /* We don't bother computing b2 for the last iteration of the + jr loop since the current thread won't know its j_st until + the next time it calls bli_thread_range_tlb(). */ \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + MR, \ + NR, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale C and add the result to only the stored part. */ \ + PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ +\ + ut += 1; \ + if ( ut == n_ut_for_me ) return; \ + } \ + else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + ctype* restrict a1 = a_cast + i * rstep_a; \ + ctype* restrict c11 = c1 + i * rstep_c; \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + ctype* restrict a2 \ + = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ +\ + ut += 1; \ + if ( ut == n_ut_for_me ) return; \ + } \ + else /* if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) */ \ + { \ + /* Skip past the microtiles strictly below the diagonal. */ \ + i = m_iter - 1; \ + } \ + } \ +\ + i = 0; \ + } \ +} + +INSERT_GENTFUNC_BASIC0( gemmt_u_ker_var2b ) + diff --git a/frame/3/gemmt/bli_gemmt_l_ker_var2.c b/frame/3/gemmt/bli_gemmt_l_ker_var2.c index 4a3a48304f..fd726da6f7 100644 --- a/frame/3/gemmt/bli_gemmt_l_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_l_ker_var2.c @@ -48,7 +48,7 @@ typedef void (*xpbys_mxn_l_vft) #undef GENTFUNC #define GENTFUNC(ctype,ch,op) \ \ -void PASTEMAC(ch,op) \ +BLIS_INLINE void PASTEMAC(ch,op) \ ( \ doff_t diagoff, \ dim_t m, \ @@ -76,18 +76,19 @@ INSERT_GENTFUNC_BASIC0(xpbys_mxn_l_fn); static xpbys_mxn_l_vft GENARRAY(xpbys_mxn_l, xpbys_mxn_l_fn); + void bli_gemmt_l_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { - const num_t dt = bli_obj_exec_dt( c ); - const dim_t dt_size = bli_dt_size( dt ); + const num_t dt_exec = bli_obj_exec_dt( c ); + const num_t dt_c = bli_obj_dt( c ); doff_t diagoffc = bli_obj_diag_offset( c ); @@ -113,7 +114,7 @@ void bli_gemmt_l_ker_var2 const inc_t cs_c = bli_obj_col_stride( c ); // Detach and multiply the scalars attached to A and B. - obj_t scalar_a, scalar_b; + obj_t scalar_a, scalar_b; bli_obj_scalar_detach( a, &scalar_a ); bli_obj_scalar_detach( b, &scalar_b ); bli_mulsc( &scalar_a, &scalar_b ); @@ -123,14 +124,17 @@ void bli_gemmt_l_ker_var2 const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + const siz_t dt_size = bli_dt_size( dt_exec ); + const siz_t dt_c_size = bli_dt_size( dt_c ); + // Alias some constants to simpler names. const dim_t MR = pd_a; const dim_t NR = pd_b; // Query the context for the micro-kernel address and cast it to its // function pointer type. - gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); - xpbys_mxn_l_vft xpbys_mxn_l_ukr = xpbys_mxn_l[ dt ]; + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + xpbys_mxn_l_vft xpbys_mxn_l_ukr = xpbys_mxn_l[ dt_exec ]; // Temporary C buffer for edge cases. Note that the strides of this // temporary buffer are set so that they match the storage of the @@ -138,11 +142,11 @@ void bli_gemmt_l_ker_var2 // column-stored as well. char ct[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); - const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt, BLIS_GEMM_VIR_UKR, cntx ); + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_VIR_UKR, cntx ); const inc_t rs_ct = ( col_pref ? 1 : NR ); const inc_t cs_ct = ( col_pref ? MR : 1 ); - const void* zero = bli_obj_buffer_for_const( dt, &BLIS_ZERO ); + const void* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); const char* a_cast = buf_a; const char* b_cast = buf_b; char* c_cast = buf_c; @@ -175,12 +179,13 @@ void bli_gemmt_l_ker_var2 // this case as if the diagonal offset were zero. if ( diagoffc < 0 ) { - dim_t ip = -diagoffc / MR; - dim_t i = ip * MR; - m = m - i; - diagoffc = -diagoffc % MR; - c_cast = c_cast + (i )*rs_c*dt_size; - a_cast = a_cast + (ip )*ps_a*dt_size; + const dim_t ip = -diagoffc / MR; + const dim_t i = ip * MR; + + m = m - i; + diagoffc = diagoffc % MR; + c_cast = c_cast + (i )*rs_c*dt_c_size; + a_cast = a_cast + (ip )*ps_a*dt_size; } // If there is a zero region to the right of where the diagonal @@ -193,25 +198,23 @@ void bli_gemmt_l_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_size; - inc_t cstep_c = cs_c * NR * dt_size; + const inc_t rstep_c = rs_c * MR * dt_c_size; + const inc_t cstep_c = cs_c * NR * dt_c_size; - // Save the pack schemas of A and B to the auxinfo_t object. auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); @@ -219,9 +222,6 @@ void bli_gemmt_l_ker_var2 bli_auxinfo_set_is_a( is_a, &aux ); bli_auxinfo_set_is_b( is_b, &aux ); - // Save the desired output datatype (indicating no typecasting). - //bli_auxinfo_set_dt_on_output( dt, &aux );*/ - // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) // loop around the microkernel. Here we query the thrinfo_t node for the // 1st (ir) loop around the microkernel. @@ -229,48 +229,21 @@ void bli_gemmt_l_ker_var2 thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); // Query the number of threads and thread ids for each loop. - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); - dim_t ir_nt = bli_thrinfo_n_way( caucus ); - dim_t ir_tid = bli_thrinfo_work_id( caucus ); - - dim_t jr_start, jr_end; - dim_t ir_start, ir_end; - dim_t jr_inc, ir_inc; + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + const dim_t ir_tid = bli_thrinfo_work_id( caucus ); - // Note that we partition the 2nd loop into two regions: the rectangular - // part of C, and the triangular portion. - dim_t n_iter_rct; - dim_t n_iter_tri; + dim_t jr_start, jr_end, jr_inc; + dim_t ir_start, ir_end, ir_inc; - if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) - { - // If the entire panel of C does not intersect the diagonal, there is - // no triangular region, and therefore we can skip the second set of - // loops. - n_iter_rct = n_iter; - n_iter_tri = 0; - } - else - { - // If the panel of C does intersect the diagonal, compute the number of - // iterations in the rectangular region by dividing NR into the diagonal - // offset. Any remainder from this integer division is discarded, which - // is what we want. That is, we want the rectangular region to contain - // as many columns of whole microtiles as possible without including any - // microtiles that intersect the diagonal. The number of iterations in - // the triangular (or trapezoidal) region is computed as the remaining - // number of iterations in the n dimension. - n_iter_rct = diagoffc / NR; - n_iter_tri = n_iter - n_iter_rct; - } - - // Determine the thread range and increment for the 2nd and 1st loops for - // the initial rectangular region of C (if it exists). - // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // Determine the thread range and increment for the 2nd and 1st loops. + // NOTE: The definition of bli_thread_range_slrr() will depend on whether // slab or round-robin partitioning was requested at configure-time. - bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + bli_thread_range_quad( thread, diagoffc, BLIS_LOWER, m, n, NR, + FALSE, &jr_start, &jr_end, &jr_inc ); + //bli_thread_range_slrr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_slrr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) @@ -278,7 +251,12 @@ void bli_gemmt_l_ker_var2 const char* b1 = b_cast + j * cstep_b; char* c1 = c_cast + j * cstep_c; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + // Compute the diagonal offset for the column of microtiles at (0,j). + const doff_t diagoffc_j = diagoffc - ( doff_t )j*NR; + + // Compute the current microtile's width. + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -286,115 +264,34 @@ void bli_gemmt_l_ker_var2 // Interior loop over the m dimension (MR rows at a time). for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) { - const char* a1 = a_cast + i * rstep_a; - char* c11 = c1 + i * rstep_c; - - // No need to compute the diagonal offset for the rectangular - // region. - //diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR;*/ + // Compute the diagonal offset for the microtile at (i,j). + const doff_t diagoffc_ij = diagoffc_j + ( doff_t )i*MR; - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + // Compute the current microtile's length. + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); - // Compute the addresses of the next panels of A and B. - const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); - if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) - { - a2 = a_cast; - b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; - } - - // Save addresses of next panels of A and B to the auxinfo_t - // object. - bli_auxinfo_set_next_a( a2, &aux ); - bli_auxinfo_set_next_b( b2, &aux ); - - // If the diagonal intersects the current MR x NR submatrix, we + // If the diagonal intersects the current MR x NR microtile, we // compute it the temporary buffer and then add in the elements // on or below the diagonal. - // Otherwise, if the submatrix is strictly below the diagonal, + // Otherwise, if the microtile is strictly below the diagonal, // we compute and store as we normally would. // And if we're strictly above the diagonal, we do nothing and - // continue. + // continue on through the IR loop to consider the next MR x NR + // microtile. + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) { - // Invoke the gemm micro-kernel. - gemm_ukr - ( - m_cur, - n_cur, - k, - ( void* )alpha_cast, - ( void* )a1, - ( void* )b1, - ( void* )beta_cast, - c11, rs_c, cs_c, - &aux, - ( cntx_t* )cntx - ); - } - } - } - - // If there is no triangular region, then we're done. - if ( n_iter_tri == 0 ) return; - - // Use round-robin assignment of micropanels to threads in the 2nd loop - // and the default (slab or rr) partitioning in the 1st loop for the - // remaining triangular region of C. - bli_thread_range_jrir_rr( thread, n_iter_tri, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - - // Advance the start and end iteration offsets for the triangular region - // by the number of iterations used for the rectangular region. - jr_start += n_iter_rct; - jr_end += n_iter_rct; - - // Loop over the n dimension (NR columns at a time). - for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) - { - const char* b1 = b_cast + j * cstep_b; - char* c1 = c_cast + j * cstep_c; - - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); - - // Initialize our next panel of B to be the current panel of B. - const char* b2 = b1; - - // Interior loop over the m dimension (MR rows at a time). - for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) - { - const char* a1 = a_cast + i * rstep_a; - char* c11 = c1 + i * rstep_c; + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; - // Compute the diagonal offset for the submatrix at (i,j). - doff_t diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; + // Compute the addresses of the next panel of A. + const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); - // Compute the addresses of the next panels of A and B. - const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); - if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) - { - a2 = a_cast; - b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; - } - - // Save addresses of next panels of A and B to the auxinfo_t - // object. - bli_auxinfo_set_next_a( a2, &aux ); - bli_auxinfo_set_next_b( b2, &aux ); - - // If the diagonal intersects the current MR x NR submatrix, we - // compute it the temporary buffer and then add in the elements - // on or below the diagonal. - // Otherwise, if the submatrix is strictly below the diagonal, - // we compute and store as we normally would. - // And if we're strictly above the diagonal, we do nothing and - // continue. - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) - { // Invoke the gemm micro-kernel. gemm_ukr ( @@ -411,14 +308,35 @@ void bli_gemmt_l_ker_var2 ); // Scale C and add the result to only the stored part. - xpbys_mxn_l_ukr( diagoffc_ij, - m_cur, n_cur, - ct, rs_ct, cs_ct, - ( void* )beta_cast, - c11, rs_c, cs_c ); + xpbys_mxn_l_ukr + ( + diagoffc_ij, + m_cur, n_cur, + ct, rs_ct, cs_ct, + ( void* )beta_cast, + c11, rs_c, cs_c + ); } else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); + if ( bli_is_last_iter_l( i, m_iter, ir_tid, ir_nt ) ) + { + a2 = bli_gemmt_l_wrap_a_upanel( a_cast, rstep_a, diagoffc_j, MR, NR ); + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); + if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) + b2 = b_cast; + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + // Invoke the gemm micro-kernel. gemm_ukr ( diff --git a/frame/3/gemmt/bli_gemmt_l_ker_var2b.c b/frame/3/gemmt/bli_gemmt_l_ker_var2b.c new file mode 100644 index 0000000000..7c50a4a540 --- /dev/null +++ b/frame/3/gemmt/bli_gemmt_l_ker_var2b.c @@ -0,0 +1,387 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +typedef void (*xpbys_mxn_l_vft) + ( + doff_t diagoff, + dim_t m, + dim_t n, + void* x, inc_t rs_x, inc_t cs_x, + void* b, + void* y, inc_t rs_y, inc_t cs_y + ); + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +\ +BLIS_INLINE void PASTEMAC(ch,op) \ + ( \ + doff_t diagoff, \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t rs_y, inc_t cs_y \ + ) \ +{ \ + ctype* restrict x_cast = x; \ + ctype* restrict b_cast = b; \ + ctype* restrict y_cast = y; \ +\ + PASTEMAC3(ch,ch,ch,xpbys_mxn_l) \ + ( \ + diagoff, \ + m, n, \ + x_cast, rs_x, cs_x, \ + b_cast, \ + y_cast, rs_y, cs_y \ + ); \ +} + +INSERT_GENTFUNC_BASIC0(xpbys_mxn_l_fn); + +static xpbys_mxn_l_vft GENARRAY(xpbys_mxn_l, xpbys_mxn_l_fn); + + +void bli_gemmt_l_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt_exec = bli_obj_exec_dt( c ); + const num_t dt_c = bli_obj_dt( c ); + + doff_t diagoffc = bli_obj_diag_offset( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + const siz_t dt_size = bli_dt_size( dt_exec ); + const siz_t dt_c_size = bli_dt_size( dt_c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + xpbys_mxn_l_vft xpbys_mxn_l_ukr = xpbys_mxn_l[ dt_exec ]; + + // Temporary C buffer for edge cases. Note that the strides of this + // temporary buffer are set so that they match the storage of the + // original C matrix. For example, if C is column-stored, ct will be + // column-stored as well. + char ct[ BLIS_STACK_BUF_MAX_SIZE ] + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_VIR_UKR, cntx ); + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + + const void* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); + const char* a_cast = buf_a; + const char* b_cast = buf_b; + char* c_cast = buf_c; + const char* alpha_cast = buf_alpha; + const char* beta_cast = buf_beta; + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Safeguard: If the current panel of C is entirely above the diagonal, + // it is not stored. So we do nothing. + if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; + + // If there is a zero region above where the diagonal of C intersects + // the left edge of the panel, adjust the pointer to C and A and treat + // this case as if the diagonal offset were zero. + // NOTE: It's possible that after this pruning that the diagonal offset + // is still negative (though its absolute value is guaranteed to be less + // than MR). + if ( diagoffc < 0 ) + { + const dim_t ip = -diagoffc / MR; + const dim_t i = ip * MR; + + m = m - i; + diagoffc = diagoffc % MR; + c_cast = c_cast + (i )*rs_c*dt_c_size; + a_cast = a_cast + (ip )*ps_a*dt_size; + } + + // If there is a zero region to the right of where the diagonal + // of C intersects the bottom of the panel, shrink it to prevent + // "no-op" iterations from executing. + if ( diagoffc + m < n ) + { + n = diagoffc + m; + } + + // Compute number of primary and leftover components of the m and n + // dimensions. + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; + + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; + + // Determine some increments used to step through A, B, and C. + const inc_t rstep_a = ps_a * dt_size; + + const inc_t cstep_b = ps_b * dt_size; + + const inc_t rstep_c = rs_c * MR * dt_c_size; + const inc_t cstep_c = cs_c * NR * dt_c_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // Save the imaginary stride of A and B to the auxinfo_t object. + bli_auxinfo_set_is_a( is_a, &aux ); + bli_auxinfo_set_is_b( is_b, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); + + // Determine the starting microtile offsets and number of microtiles to + // compute for each thread. Note that assignment of microtiles is done + // according to the tlb policy. + dim_t jr_st, ir_st; + const dim_t n_ut_for_me + = + bli_thread_range_tlb_l( jr_nt, jr_tid, diagoffc, m_iter, n_iter, MR, NR, + &jr_st, &ir_st ); + + // It's possible that there are so few microtiles relative to the number + // of threads that one or more threads gets no work. If that happens, those + // threads can return early. + if ( n_ut_for_me == 0 ) return; + + // Start the jr/ir loops with the current thread's microtile offsets computed + // by bli_thread_range_tlb(). + dim_t i = ir_st; + dim_t j = jr_st; + + // Initialize a counter to track the number of microtiles computed by the + // current thread. + dim_t ut = 0; + + // Loop over the n dimension (NR columns at a time). + for ( ; true; ++j ) + { + const char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + // Compute the diagonal offset for the column of microtiles at (0,j). + const doff_t diagoffc_j = diagoffc - ( doff_t )j*NR; + + // Compute the current microtile's width. + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + bli_auxinfo_set_next_b( b2, &aux ); + + // Interior loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + // Compute the diagonal offset for the microtile at (i,j). + const doff_t diagoffc_ij = diagoffc_j + ( doff_t )i*MR; + + // Compute the current microtile's length. + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + // If the diagonal intersects the current MR x NR microtile, we + // compute it the temporary buffer and then add in the elements + // on or below the diagonal. + // Otherwise, if the microtile is strictly below the diagonal, + // we compute and store as we normally would. + // And if we're strictly above the diagonal, we simply advance + // to the last microtile before the diagonal. + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + // Compute the addresses of the next panel of A. + const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, 1 ); + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + MR, + NR, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )zero, + ct, rs_ct, cs_ct, + &aux, + ( cntx_t* )cntx + ); + + // Scale C and add the result to only the stored part. + xpbys_mxn_l_ukr + ( + diagoffc_ij, + m_cur, n_cur, + ct, rs_ct, cs_ct, + ( void* )beta_cast, + c11, rs_c, cs_c + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; + if ( ut == n_ut_for_me ) return; + } + else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_tlb_l( i, m_iter ) ) + { + a2 = bli_gemmt_l_wrap_a_upanel( a_cast, rstep_a, diagoffc_j, MR, NR ); + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; + if ( ut == n_ut_for_me ) return; + } + else // if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) + { + // Skip ahead to the last microtile strictly above the diagonal. + i = -diagoffc_j / MR - 1; + } + } + + // Upon reaching the end of the column of microtiles, get ready to begin + // at the beginning of the next column (i.e., the next jr loop iteration). + i = 0; + } +} + diff --git a/frame/3/gemmt/bli_gemmt_u_ker_var2.c b/frame/3/gemmt/bli_gemmt_u_ker_var2.c index 5b4e1ccd96..78d5b869d2 100644 --- a/frame/3/gemmt/bli_gemmt_u_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_u_ker_var2.c @@ -48,7 +48,7 @@ typedef void (*xpbys_mxn_u_vft) #undef GENTFUNC #define GENTFUNC(ctype,ch,op) \ \ -void PASTEMAC(ch,op) \ +BLIS_INLINE void PASTEMAC(ch,op) \ ( \ doff_t diagoff, \ dim_t m, \ @@ -76,18 +76,19 @@ INSERT_GENTFUNC_BASIC0(xpbys_mxn_u_fn); static xpbys_mxn_u_vft GENARRAY(xpbys_mxn_u, xpbys_mxn_u_fn); + void bli_gemmt_u_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { - const num_t dt = bli_obj_exec_dt( c ); - const dim_t dt_size = bli_dt_size( dt ); + const num_t dt_exec = bli_obj_exec_dt( c ); + const num_t dt_c = bli_obj_dt( c ); doff_t diagoffc = bli_obj_diag_offset( c ); @@ -113,7 +114,7 @@ void bli_gemmt_u_ker_var2 const inc_t cs_c = bli_obj_col_stride( c ); // Detach and multiply the scalars attached to A and B. - obj_t scalar_a, scalar_b; + obj_t scalar_a, scalar_b; bli_obj_scalar_detach( a, &scalar_a ); bli_obj_scalar_detach( b, &scalar_b ); bli_mulsc( &scalar_a, &scalar_b ); @@ -123,14 +124,17 @@ void bli_gemmt_u_ker_var2 const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + const siz_t dt_size = bli_dt_size( dt_exec ); + const siz_t dt_c_size = bli_dt_size( dt_c ); + // Alias some constants to simpler names. const dim_t MR = pd_a; const dim_t NR = pd_b; // Query the context for the micro-kernel address and cast it to its // function pointer type. - gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); - xpbys_mxn_u_vft xpbys_mxn_u_ukr = xpbys_mxn_u[ dt ]; + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + xpbys_mxn_u_vft xpbys_mxn_u_ukr = xpbys_mxn_u[ dt_exec ]; // Temporary C buffer for edge cases. Note that the strides of this // temporary buffer are set so that they match the storage of the @@ -138,11 +142,11 @@ void bli_gemmt_u_ker_var2 // column-stored as well. char ct[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); - const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt, BLIS_GEMM_VIR_UKR, cntx ); + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_VIR_UKR, cntx ); const inc_t rs_ct = ( col_pref ? 1 : NR ); const inc_t cs_ct = ( col_pref ? MR : 1 ); - const void* zero = bli_obj_buffer_for_const( dt, &BLIS_ZERO ); + const void* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); const char* a_cast = buf_a; const char* b_cast = buf_b; char* c_cast = buf_c; @@ -177,12 +181,13 @@ void bli_gemmt_u_ker_var2 // is still positive (though it is guaranteed to be less than NR). if ( diagoffc > 0 ) { - dim_t jp = diagoffc / NR; - dim_t j = jp * NR; - n = n - j; - diagoffc = diagoffc % NR; - c_cast = c_cast + (j )*cs_c*dt_size; - b_cast = b_cast + (jp )*ps_b*dt_size; + const dim_t jp = diagoffc / NR; + const dim_t j = jp * NR; + + n = n - j; + diagoffc = diagoffc % NR; + c_cast = c_cast + (j )*cs_c*dt_c_size; + b_cast = b_cast + (jp )*ps_b*dt_size; } // If there is a zero region below where the diagonal of C intersects @@ -195,25 +200,23 @@ void bli_gemmt_u_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_size; - inc_t cstep_c = cs_c * NR * dt_size; + const inc_t rstep_c = rs_c * MR * dt_c_size; + const inc_t cstep_c = cs_c * NR * dt_c_size; - // Save the pack schemas of A and B to the auxinfo_t object. auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); @@ -221,9 +224,6 @@ void bli_gemmt_u_ker_var2 bli_auxinfo_set_is_a( is_a, &aux ); bli_auxinfo_set_is_b( is_b, &aux ); - // Save the desired output datatype (indicating no typecasting). - //bli_auxinfo_set_dt_on_output( dt, &aux );*/ - // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) // loop around the microkernel. Here we query the thrinfo_t node for the // 1st (ir) loop around the microkernel. @@ -231,47 +231,21 @@ void bli_gemmt_u_ker_var2 thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); // Query the number of threads and thread ids for each loop. - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); - dim_t ir_nt = bli_thrinfo_n_way( caucus ); - dim_t ir_tid = bli_thrinfo_work_id( caucus ); - - dim_t jr_start, jr_end; - dim_t ir_start, ir_end; - dim_t jr_inc, ir_inc; + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); - // Note that we partition the 2nd loop into two regions: the triangular - // part of C, and the rectangular portion. - dim_t n_iter_tri; - dim_t n_iter_rct; + dim_t jr_start, jr_end, jr_inc; + dim_t ir_start, ir_end, ir_inc; - if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) - { - // If the entire panel of C does not intersect the diagonal, there is - // no triangular region, and therefore we can skip the first set of - // loops. - n_iter_tri = 0; - n_iter_rct = n_iter; - } - else - { - // If the panel of C does intersect the diagonal, compute the number of - // iterations in the triangular (or trapezoidal) region by dividing NR - // into the number of rows in C. A non-zero remainder means we need to - // add one additional iteration. That is, we want the triangular region - // to contain as few columns of whole microtiles as possible while still - // including all microtiles that intersect the diagonal. The number of - // iterations in the rectangular region is computed as the remaining - // number of iterations in the n dimension. - n_iter_tri = ( m + diagoffc ) / NR + ( ( m + diagoffc ) % NR ? 1 : 0 ); - n_iter_rct = n_iter - n_iter_tri; - } - - // Use round-robin assignment of micropanels to threads in the 2nd loop - // and the default (slab or rr) partitioning in the 1st loop for the - // initial triangular region of C (if it exists). - bli_thread_range_jrir_rr( thread, n_iter_tri, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - bli_thread_range_jrir ( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + // Determine the thread range and increment for the 2nd and 1st loops. + // NOTE: The definition of bli_thread_range_slrr() will depend on whether + // slab or round-robin partitioning was requested at configure-time. + bli_thread_range_quad( thread, diagoffc, BLIS_UPPER, m, n, NR, + FALSE, &jr_start, &jr_end, &jr_inc ); + //bli_thread_range_slrr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_slrr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) @@ -279,7 +253,12 @@ void bli_gemmt_u_ker_var2 const char* b1 = b_cast + j * cstep_b; char* c1 = c_cast + j * cstep_c; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + // Compute the diagonal offset for the column of microtiles at (0,j). + const doff_t diagoffc_j = diagoffc - ( doff_t )j*NR; + + // Compute the current microtile's width. + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -287,38 +266,41 @@ void bli_gemmt_u_ker_var2 // Interior loop over the m dimension (MR rows at a time). for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) { - const char* a1 = a_cast + i * rstep_a; - char* c11 = c1 + i * rstep_c; - - // Compute the diagonal offset for the submatrix at (i,j). - doff_t diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; + // Compute the diagonal offset for the microtile at (i,j). + const doff_t diagoffc_ij = diagoffc_j + ( doff_t )i*MR; - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + // Compute the current microtile's length. + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); - // Compute the addresses of the next panels of A and B. - const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); - if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) - { - a2 = a_cast; - b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; - } - - // Save addresses of next panels of A and B to the auxinfo_t - // object. - bli_auxinfo_set_next_a( a2, &aux ); - bli_auxinfo_set_next_b( b2, &aux ); - - // If the diagonal intersects the current MR x NR submatrix, we + // If the diagonal intersects the current MR x NR microtile, we // compute it the temporary buffer and then add in the elements // on or below the diagonal. - // Otherwise, if the submatrix is strictly above the diagonal, + // Otherwise, if the microtile is strictly above the diagonal, // we compute and store as we normally would. // And if we're strictly below the diagonal, we do nothing and - // continue. + // continue on through the IR loop to consider the next MR x NR + // microtile. if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); + if ( bli_is_last_iter_u( diagoffc_ij, MR, NR, ir_inc ) ) + { + a2 = bli_gemmt_u_wrap_a_upanel( a_cast, rstep_a, diagoffc_j, MR, NR ); + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); + if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) + b2 = b_cast; + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + // Invoke the gemm micro-kernel. gemm_ukr ( @@ -335,93 +317,28 @@ void bli_gemmt_u_ker_var2 ); // Scale C and add the result to only the stored part. - xpbys_mxn_u_ukr( diagoffc_ij, - m_cur, n_cur, - ct, rs_ct, cs_ct, - ( void* )beta_cast, - c11, rs_c, cs_c ); - } - else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) - { - // Invoke the gemm micro-kernel. - gemm_ukr + xpbys_mxn_u_ukr ( - m_cur, - n_cur, - k, - ( void* )alpha_cast, - ( void* )a1, - ( void* )b1, + diagoffc_ij, + m_cur, n_cur, + ct, rs_ct, cs_ct, ( void* )beta_cast, - c11, rs_c, cs_c, - &aux, - ( cntx_t* )cntx + c11, rs_c, cs_c ); } - } - } - - // If there is no rectangular region, then we're done. - if ( n_iter_rct == 0 ) return; - - // Determine the thread range and increment for the 2nd loop of the - // remaining rectangular region of C (and also use default partitioning - // for the 1st loop). - // NOTE: The definition of bli_thread_range_jrir() will depend on whether - // slab or round-robin partitioning was requested at configure-time. - bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - - // Advance the start and end iteration offsets for the rectangular region - // by the number of iterations used for the triangular region. - jr_start += n_iter_tri; - jr_end += n_iter_tri; - - // Loop over the n dimension (NR columns at a time). - for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) - { - const char* b1 = b_cast + j * cstep_b; - char* c1 = c_cast + j * cstep_c; - - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); - - // Initialize our next panel of B to be the current panel of B. - const char* b2 = b1; - - // Interior loop over the m dimension (MR rows at a time). - for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) - { - const char* a1 = a_cast + i * rstep_a; - char* c11 = c1 + i * rstep_c; - - // No need to compute the diagonal offset for the rectangular - // region. - //diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR;*/ - - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); - - // Compute the addresses of the next panels of A and B. - const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); - if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) + else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) { - a2 = a_cast; - b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; - } + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; - // Save addresses of next panels of A and B to the auxinfo_t - // object. - bli_auxinfo_set_next_a( a2, &aux ); - bli_auxinfo_set_next_b( b2, &aux ); + // Compute the addresses of the next panel of A. + const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); - // If the diagonal intersects the current MR x NR submatrix, we - // compute it the temporary buffer and then add in the elements - // on or below the diagonal. - // Otherwise, if the submatrix is strictly above the diagonal, - // we compute and store as we normally would. - // And if we're strictly below the diagonal, we do nothing and - // continue. - { // Invoke the gemm micro-kernel. gemm_ukr ( diff --git a/frame/3/gemmt/bli_gemmt_u_ker_var2b.c b/frame/3/gemmt/bli_gemmt_u_ker_var2b.c new file mode 100644 index 0000000000..91275577a4 --- /dev/null +++ b/frame/3/gemmt/bli_gemmt_u_ker_var2b.c @@ -0,0 +1,386 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +typedef void (*xpbys_mxn_u_vft) + ( + doff_t diagoff, + dim_t m, + dim_t n, + void* x, inc_t rs_x, inc_t cs_x, + void* b, + void* y, inc_t rs_y, inc_t cs_y + ); + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +\ +BLIS_INLINE void PASTEMAC(ch,op) \ + ( \ + doff_t diagoff, \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t rs_y, inc_t cs_y \ + ) \ +{ \ + ctype* restrict x_cast = x; \ + ctype* restrict b_cast = b; \ + ctype* restrict y_cast = y; \ +\ + PASTEMAC3(ch,ch,ch,xpbys_mxn_u) \ + ( \ + diagoff, \ + m, n, \ + x_cast, rs_x, cs_x, \ + b_cast, \ + y_cast, rs_y, cs_y \ + ); \ +} + +INSERT_GENTFUNC_BASIC0(xpbys_mxn_u_fn); + +static xpbys_mxn_u_vft GENARRAY(xpbys_mxn_u, xpbys_mxn_u_fn); + + +void bli_gemmt_u_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt_exec = bli_obj_exec_dt( c ); + const num_t dt_c = bli_obj_dt( c ); + + doff_t diagoffc = bli_obj_diag_offset( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + const siz_t dt_size = bli_dt_size( dt_exec ); + const siz_t dt_c_size = bli_dt_size( dt_c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + xpbys_mxn_u_vft xpbys_mxn_u_ukr = xpbys_mxn_u[ dt_exec ]; + + // Temporary C buffer for edge cases. Note that the strides of this + // temporary buffer are set so that they match the storage of the + // original C matrix. For example, if C is column-stored, ct will be + // column-stored as well. + char ct[ BLIS_STACK_BUF_MAX_SIZE ] + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_VIR_UKR, cntx ); + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + + const void* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); + const char* a_cast = buf_a; + const char* b_cast = buf_b; + char* c_cast = buf_c; + const char* alpha_cast = buf_alpha; + const char* beta_cast = buf_beta; + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Safeguard: If the current panel of C is entirely below the diagonal, + // it is not stored. So we do nothing. + if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; + + // If there is a zero region to the left of where the diagonal of C + // intersects the top edge of the panel, adjust the pointer to C and B + // and treat this case as if the diagonal offset were zero. + // NOTE: It's possible that after this pruning that the diagonal offset + // is still positive (though it is guaranteed to be less than NR). + if ( diagoffc > 0 ) + { + const dim_t jp = diagoffc / NR; + const dim_t j = jp * NR; + + n = n - j; + diagoffc = diagoffc % NR; + c_cast = c_cast + (j )*cs_c*dt_c_size; + b_cast = b_cast + (jp )*ps_b*dt_size; + } + + // If there is a zero region below where the diagonal of C intersects + // the right edge of the panel, shrink it to prevent "no-op" iterations + // from executing. + if ( -diagoffc + n < m ) + { + m = -diagoffc + n; + } + + // Compute number of primary and leftover components of the m and n + // dimensions. + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; + + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; + + // Determine some increments used to step through A, B, and C. + const inc_t rstep_a = ps_a * dt_size; + + const inc_t cstep_b = ps_b * dt_size; + + const inc_t rstep_c = rs_c * MR * dt_c_size; + const inc_t cstep_c = cs_c * NR * dt_c_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // Save the imaginary stride of A and B to the auxinfo_t object. + bli_auxinfo_set_is_a( is_a, &aux ); + bli_auxinfo_set_is_b( is_b, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); + + // Determine the starting microtile offsets and number of microtiles to + // compute for each thread. Note that assignment of microtiles is done + // according to the tlb policy. + dim_t jr_st, ir_st; + const dim_t n_ut_for_me + = + bli_thread_range_tlb_u( jr_nt, jr_tid, diagoffc, m_iter, n_iter, MR, NR, + &jr_st, &ir_st ); + + // It's possible that there are so few microtiles relative to the number + // of threads that one or more threads gets no work. If that happens, those + // threads can return early. + if ( n_ut_for_me == 0 ) return; + + // Start the jr/ir loops with the current thread's microtile offsets computed + // by bli_thread_range_tlb(). + dim_t i = ir_st; + dim_t j = jr_st; + + // Initialize a counter to track the number of microtiles computed by the + // current thread. + dim_t ut = 0; + + // Loop over the n dimension (NR columns at a time). + for ( ; true; ++j ) + { + const char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + // Compute the diagonal offset for the column of microtiles at (0,j). + const doff_t diagoffc_j = diagoffc - ( doff_t )j*NR; + + // Compute the current microtile's width. + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + bli_auxinfo_set_next_b( b2, &aux ); + + // Interior loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + // Compute the diagonal offset for the microtile at (i,j). + const doff_t diagoffc_ij = diagoffc_j + ( doff_t )i*MR; + + // Compute the current microtile's length. + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + // If the diagonal intersects the current MR x NR microtile, we + // compute it the temporary buffer and then add in the elements + // on or below the diagonal. + // Otherwise, if the microtile is strictly above the diagonal, + // we compute and store as we normally would. + // And if we're strictly below the diagonal, we simply advance + // to last microtile before the bottom of the matrix. + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_tlb_u( diagoffc_ij, MR, NR ) ) + { + a2 = bli_gemmt_u_wrap_a_upanel( a_cast, rstep_a, diagoffc_j, MR, NR ); + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + MR, + NR, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )zero, + ct, rs_ct, cs_ct, + &aux, + ( cntx_t* )cntx + ); + + // Scale C and add the result to only the stored part. + xpbys_mxn_u_ukr + ( + diagoffc_ij, + m_cur, n_cur, + ct, rs_ct, cs_ct, + ( void* )beta_cast, + c11, rs_c, cs_c + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; + if ( ut == n_ut_for_me ) return; + } + else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + // Compute the addresses of the next panel of A. + const char* a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, 1 ); + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; + if ( ut == n_ut_for_me ) return; + } + else // if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) + { + // Skip past the microtiles strictly below the diagonal. + i = m_iter - 1; + } + } + + // Upon reaching the end of the column of microtiles, get ready to begin + // at the beginning of the next column (i.e., the next jr loop iteration). + i = 0; + } +} + diff --git a/frame/3/gemmt/bli_gemmt_var.h b/frame/3/gemmt/bli_gemmt_var.h index eb6e160180..339b937555 100644 --- a/frame/3/gemmt/bli_gemmt_var.h +++ b/frame/3/gemmt/bli_gemmt_var.h @@ -43,46 +43,19 @@ \ void PASTEMAC0(opname) \ ( \ - const obj_t* a, \ - const obj_t* ah, \ - const obj_t* c, \ - const cntx_t* cntx, \ - const cntl_t* cntl, \ - thrinfo_t* thread \ + const obj_t* a, \ + const obj_t* ah, \ + const obj_t* c, \ + const cntx_t* cntx, \ + const cntl_t* cntl, \ + thrinfo_t* thread_par \ ); GENPROT( gemmt_x_ker_var2 ) - GENPROT( gemmt_l_ker_var2 ) GENPROT( gemmt_u_ker_var2 ) - -// -// Prototype BLAS-like interfaces with void pointer operands. -// - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - thrinfo_t* thread \ - ); - -INSERT_GENTPROT_BASIC0( gemmt_l_ker_var2 ) -INSERT_GENTPROT_BASIC0( gemmt_u_ker_var2 ) +GENPROT( gemmt_x_ker_var2b ) +GENPROT( gemmt_l_ker_var2b ) +GENPROT( gemmt_u_ker_var2b ) diff --git a/frame/3/gemmt/bli_gemmt_x_ker_var2.c b/frame/3/gemmt/bli_gemmt_x_ker_var2.c index 207e1c938f..8081537b91 100644 --- a/frame/3/gemmt/bli_gemmt_x_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_x_ker_var2.c @@ -42,12 +42,12 @@ static l3_var_oft vars[2] = void bli_gemmt_x_ker_var2 ( - const obj_t* a, - const obj_t* ah, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, - thrinfo_t* thread + const obj_t* a, + const obj_t* ah, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par ) { dim_t uplo; @@ -67,7 +67,7 @@ void bli_gemmt_x_ker_var2 c, cntx, cntl, - thread + thread_par ); } diff --git a/frame/3/gemmt/bli_gemmt_x_ker_var2b.c b/frame/3/gemmt/bli_gemmt_x_ker_var2b.c new file mode 100644 index 0000000000..132d7c13a9 --- /dev/null +++ b/frame/3/gemmt/bli_gemmt_x_ker_var2b.c @@ -0,0 +1,73 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +static l3_var_oft vars[2] = +{ + bli_gemmt_l_ker_var2b, bli_gemmt_u_ker_var2b, +}; + +void bli_gemmt_x_ker_var2b + ( + const obj_t* a, + const obj_t* ah, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + dim_t uplo; + + // Set a bool based on the uplo field of C's root object. + if ( bli_obj_root_is_lower( c ) ) uplo = 0; + else uplo = 1; + + // Index into the variant array to extract the correct function pointer. + l3_var_oft f = vars[uplo]; + + // Call the macrokernel. + f + ( + a, + ah, + c, + cntx, + cntl, + thread_par + ); +} + diff --git a/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c.prev b/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c.prev new file mode 100644 index 0000000000..aed0359ecb --- /dev/null +++ b/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c.prev @@ -0,0 +1,507 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmt_fp + +typedef void (*FUNCPTR_T) + ( + doff_t diagoffc, + pack_t schema_a, + pack_t schema_b, + dim_t m, + dim_t n, + dim_t k, + void* alpha, + void* a, inc_t cs_a, inc_t is_a, + dim_t pd_a, inc_t ps_a, + void* b, inc_t rs_b, inc_t is_b, + dim_t pd_b, inc_t ps_b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +static FUNCPTR_T GENARRAY(ftypes,gemmt_l_ker_var2); + + +void bli_gemmt_l_ker_var2 + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + const num_t dt_exec = bli_obj_exec_dt( c ); + + const doff_t diagoffc = bli_obj_diag_offset( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Index into the type combination array to extract the correct + // function pointer. + ftypes[dt_exec] + ( + diagoffc, + schema_a, + schema_b, + m, + n, + k, + ( void* )buf_alpha, + ( void* )buf_a, cs_a, is_a, + pd_a, ps_a, + ( void* )buf_b, rs_b, is_b, + pd_b, ps_b, + ( void* )buf_beta, + buf_c, rs_c, cs_c, + ( cntx_t* )cntx, + rntm, + thread + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffc, \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt, BLIS_GEMM_VIR_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ + ctype* restrict b1; \ + ctype* restrict c1; \ +\ + doff_t diagoffc_ij; \ + dim_t m_iter, m_left; \ + dim_t n_iter, n_left; \ + dim_t m_cur; \ + dim_t n_cur; \ + dim_t i, j, ip; \ + inc_t rstep_a; \ + inc_t cstep_b; \ + inc_t rstep_c, cstep_c; \ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Safeguard: If the current panel of C is entirely above the diagonal, + it is not stored. So we do nothing. */ \ + if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; \ +\ + /* If there is a zero region above where the diagonal of C intersects + the left edge of the panel, adjust the pointer to C and A and treat + this case as if the diagonal offset were zero. */ \ + if ( diagoffc < 0 ) \ + { \ + ip = -diagoffc / MR; \ + i = ip * MR; \ + m = m - i; \ + diagoffc = -diagoffc % MR; \ + c_cast = c_cast + (i )*rs_c; \ + a_cast = a_cast + (ip )*ps_a; \ + } \ +\ + /* If there is a zero region to the right of where the diagonal + of C intersects the bottom of the panel, shrink it to prevent + "no-op" iterations from executing. */ \ + if ( diagoffc + m < n ) \ + { \ + n = diagoffc + m; \ + } \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + n_iter = n / NR; \ + n_left = n % NR; \ +\ + m_iter = m / MR; \ + m_left = m % MR; \ +\ + if ( n_left ) ++n_iter; \ + if ( m_left ) ++m_iter; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + rstep_a = ps_a; \ +\ + cstep_b = ps_b; \ +\ + rstep_c = rs_c * MR; \ + cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ +\ + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + loop around the microkernel. Here we query the thrinfo_t node for the + 1st (ir) loop around the microkernel. */ \ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ +\ + /* Query the number of threads and thread ids for each loop. */ \ + dim_t jr_nt = bli_thread_n_way( thread ); \ + dim_t jr_tid = bli_thread_work_id( thread ); \ + dim_t ir_nt = bli_thread_n_way( caucus ); \ + dim_t ir_tid = bli_thread_work_id( caucus ); \ +\ + dim_t jr_start, jr_end; \ + dim_t ir_start, ir_end; \ + dim_t jr_inc, ir_inc; \ +\ + /* Note that we partition the 2nd loop into two regions: the rectangular + part of C, and the triangular portion. */ \ + dim_t n_iter_rct; \ + dim_t n_iter_tri; \ +\ + if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) \ + { \ + /* If the entire panel of C does not intersect the diagonal, there is + no triangular region, and therefore we can skip the second set of + loops. */ \ + n_iter_rct = n_iter; \ + n_iter_tri = 0; \ + } \ + else \ + { \ + /* If the panel of C does intersect the diagonal, compute the number of + iterations in the rectangular region by dividing NR into the diagonal + offset. Any remainder from this integer division is discarded, which + is what we want. That is, we want the rectangular region to contain + as many columns of whole microtiles as possible without including any + microtiles that intersect the diagonal. The number of iterations in + the triangular (or trapezoidal) region is computed as the remaining + number of iterations in the n dimension. */ \ + n_iter_rct = diagoffc / NR; \ + n_iter_tri = n_iter - n_iter_rct; \ + } \ +\ + /* Determine the thread range and increment for the 2nd and 1st loops for + the initial rectangular region of C (if it exists). + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( j = jr_start; j < jr_end; j += jr_inc ) \ + { \ + ctype* restrict a1; \ + ctype* restrict c11; \ + ctype* restrict b2; \ +\ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ +\ + n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( i = ir_start; i < ir_end; i += ir_inc ) \ + { \ + ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ +\ + /* No need to compute the diagonal offset for the rectangular + region. */ \ + /*diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR;*/ \ +\ + m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ + { \ + a2 = a_cast; \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* If the diagonal intersects the current MR x NR submatrix, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the submatrix is strictly below the diagonal, + we compute and store as we normally would. + And if we're strictly above the diagonal, we do nothing and + continue. */ \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* If there is no triangular region, then we're done. */ \ + if ( n_iter_tri == 0 ) return; \ +\ + /* Use round-robin assignment of micropanels to threads in the 2nd loop + and the default (slab or rr) partitioning in the 1st loop for the + remaining triangular region of C. */ \ + bli_thread_range_jrir_rr( thread, n_iter_tri, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ +\ + /* Advance the start and end iteration offsets for the triangular region + by the number of iterations used for the rectangular region. */ \ + jr_start += n_iter_rct; \ + jr_end += n_iter_rct; \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( j = jr_start; j < jr_end; j += jr_inc ) \ + { \ + ctype* restrict a1; \ + ctype* restrict c11; \ + ctype* restrict b2; \ +\ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ +\ + n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( i = ir_start; i < ir_end; i += ir_inc ) \ + { \ + ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ +\ + /* Compute the diagonal offset for the submatrix at (i,j). */ \ + diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ +\ + m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ + { \ + a2 = a_cast; \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* If the diagonal intersects the current MR x NR submatrix, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the submatrix is strictly below the diagonal, + we compute and store as we normally would. + And if we're strictly above the diagonal, we do nothing and + continue. */ \ + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + MR, \ + NR, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale C and add the result to only the stored part. */ \ + PASTEMAC(ch,xpbys_mxn_l)( diagoffc_ij, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( gemmt_l_ker_var2 ) + diff --git a/frame/3/gemmt/other/bli_gemmt_l_ker_var2b.c.before b/frame/3/gemmt/other/bli_gemmt_l_ker_var2b.c.before new file mode 100644 index 0000000000..4285bd1356 --- /dev/null +++ b/frame/3/gemmt/other/bli_gemmt_l_ker_var2b.c.before @@ -0,0 +1,427 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmt_fp + +typedef void (*FUNCPTR_T) + ( + doff_t diagoffc, + pack_t schema_a, + pack_t schema_b, + dim_t m, + dim_t n, + dim_t k, + void* alpha, + void* a, inc_t cs_a, inc_t is_a, + dim_t pd_a, inc_t ps_a, + void* b, inc_t rs_b, inc_t is_b, + dim_t pd_b, inc_t ps_b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +static FUNCPTR_T GENARRAY(ftypes,gemmt_l_ker_var2b); + + +void bli_gemmt_l_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + const num_t dt_exec = bli_obj_exec_dt( c ); + + const doff_t diagoffc = bli_obj_diag_offset( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Index into the type combination array to extract the correct + // function pointer. + ftypes[dt_exec] + ( + diagoffc, + schema_a, + schema_b, + m, + n, + k, + ( void* )buf_alpha, + ( void* )buf_a, cs_a, is_a, + pd_a, ps_a, + ( void* )buf_b, rs_b, is_b, + pd_b, ps_b, + ( void* )buf_beta, + buf_c, rs_c, cs_c, + ( cntx_t* )cntx, + rntm, + thread + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffc, \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt, BLIS_GEMM_VIR_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Safeguard: If the current panel of C is entirely above the diagonal, + it is not stored. So we do nothing. */ \ + if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; \ +\ + /* If there is a zero region above where the diagonal of C intersects + the left edge of the panel, adjust the pointer to C and A and treat + this case as if the diagonal offset were zero. + NOTE: It's possible that after this pruning that the diagonal offset + is still negative (though its absolute value is guaranteed to be less + than MR). */ \ + if ( diagoffc < 0 ) \ + { \ + const dim_t ip = -diagoffc / MR; \ + const dim_t i = ip * MR; \ +\ + m = m - i; \ + diagoffc = diagoffc % MR; \ + c_cast = c_cast + (i )*rs_c; \ + a_cast = a_cast + (ip )*ps_a; \ + } \ +\ + /* If there is a zero region to the right of where the diagonal + of C intersects the bottom of the panel, shrink it to prevent + "no-op" iterations from executing. */ \ + if ( diagoffc + m < n ) \ + { \ + n = diagoffc + m; \ + } \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); \ + const dim_t n_left = n % NR; \ +\ + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); \ + const dim_t m_left = m % MR; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + const inc_t rstep_a = ps_a; \ +\ + const inc_t cstep_b = ps_b; \ +\ + const inc_t rstep_c = rs_c * MR; \ + const inc_t cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the virtual microkernel address and the params. */ \ + /*bli_auxinfo_set_ukr( gemm_ukr, &aux );*/ \ + /*bli_auxinfo_set_params( params, &aux );*/ \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ +\ + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + loop around the microkernel. Here we query the thrinfo_t node for the + 1st (ir) loop around the microkernel. */ \ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ +\ + /* Query the number of threads and thread ids for each loop. */ \ + const dim_t jr_nt = bli_thread_n_way( thread ); \ + const dim_t jr_tid = bli_thread_work_id( thread ); \ + const dim_t ir_nt = bli_thread_n_way( caucus ); \ + const dim_t ir_tid = bli_thread_work_id( caucus ); \ +\ + dim_t jr_start, jr_end; \ + dim_t ir_start, ir_end; \ + dim_t jr_inc, ir_inc; \ +\ + /* Determine the thread range and increment for the 2nd and 1st loops. + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. */ \ +/* +*/ \ + bli_thread_range_weighted_jr( thread, diagoffc, BLIS_LOWER, m, n, NR, \ + FALSE, &jr_start, &jr_end, &jr_inc ); \ + /*bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );*/ \ +/* +*/ \ + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ +/* + dim_t jr_st, ir_st; \ + const dim_t n_ut_for_me \ + = \ + bli_thread_range_tlb( thread, diagoffc, BLIS_LOWER, m, n, MR, NR, \ + &jr_st, &ir_st ); \ +*/ \ +\ +/* +printf( "bli_gemmt_l_ker_var2b(): tid %d: m n = %d %d st en in = %3d %3d %3d do %d\n", (int)jr_tid, (int)m, (int)n, (int)jr_start, (int)jr_end, (int)jr_inc, (int)diagoffc ); \ +*/ \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) \ + { \ + ctype* restrict b1 = b_cast + j * cstep_b; \ + ctype* restrict c1 = c_cast + j * cstep_c; \ +\ + /* Compute the diagonal offset for the column of microtiles at (0,j). */ \ + const doff_t diagoffc_j = diagoffc - (doff_t)j*NR; \ + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) \ + ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + ctype* restrict b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) \ + { \ + /* Compute the diagonal offset for the microtile at (i,j). */ \ + const doff_t diagoffc_ij = diagoffc_j + (doff_t)i*MR; \ + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) \ + ? MR : m_left ); \ +\ + /* If the diagonal intersects the current MR x NR microtile, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the microtile is strictly below the diagonal, + we compute and store as we normally would. + And if we're strictly above the diagonal, we do nothing and + continue on through the IR loop to consider the next MR x NR + microtile. */ \ + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + ctype* restrict a1 = a_cast + i * rstep_a; \ + ctype* restrict c11 = c1 + i * rstep_c; \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + ctype* restrict a2 \ + = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ + { \ + a2 = bli_gemmt_l_wrap_a_upanel( a_cast, rstep_a, \ + diagoffc_j, MR, NR ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + MR, \ + NR, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale C and add the result to only the stored part. */ \ + PASTEMAC(ch,xpbys_mxn_l)( diagoffc_ij, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + ctype* restrict a1 = a_cast + i * rstep_a; \ + ctype* restrict c11 = c1 + i * rstep_c; \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + ctype* restrict a2 \ + = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ + { \ + a2 = bli_gemmt_l_wrap_a_upanel( a_cast, rstep_a, \ + diagoffc_j, MR, NR ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( gemmt_l_ker_var2b ) + diff --git a/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c.prev b/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c.prev new file mode 100644 index 0000000000..87d77ee554 --- /dev/null +++ b/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c.prev @@ -0,0 +1,510 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmt_fp + +typedef void (*FUNCPTR_T) + ( + doff_t diagoffc, + pack_t schema_a, + pack_t schema_b, + dim_t m, + dim_t n, + dim_t k, + void* alpha, + void* a, inc_t cs_a, inc_t is_a, + dim_t pd_a, inc_t ps_a, + void* b, inc_t rs_b, inc_t is_b, + dim_t pd_b, inc_t ps_b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +static FUNCPTR_T GENARRAY(ftypes,gemmt_u_ker_var2); + + +void bli_gemmt_u_ker_var2 + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + const num_t dt_exec = bli_obj_exec_dt( c ); + + const doff_t diagoffc = bli_obj_diag_offset( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Index into the type combination array to extract the correct + // function pointer. + ftypes[dt_exec] + ( + diagoffc, + schema_a, + schema_b, + m, + n, + k, + ( void* )buf_alpha, + ( void* )buf_a, cs_a, is_a, + pd_a, ps_a, + ( void* )buf_b, rs_b, is_b, + pd_b, ps_b, + ( void* )buf_beta, + buf_c, rs_c, cs_c, + ( cntx_t* )cntx, + rntm, + thread + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffc, \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt, BLIS_GEMM_VIR_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ + ctype* restrict b1; \ + ctype* restrict c1; \ +\ + doff_t diagoffc_ij; \ + dim_t m_iter, m_left; \ + dim_t n_iter, n_left; \ + dim_t m_cur; \ + dim_t n_cur; \ + dim_t i, j, jp; \ + inc_t rstep_a; \ + inc_t cstep_b; \ + inc_t rstep_c, cstep_c; \ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Safeguard: If the current panel of C is entirely below the diagonal, + it is not stored. So we do nothing. */ \ + if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \ +\ + /* If there is a zero region to the left of where the diagonal of C + intersects the top edge of the panel, adjust the pointer to C and B + and treat this case as if the diagonal offset were zero. + NOTE: It's possible that after this pruning that the diagonal offset + is still positive (though it is guaranteed to be less than NR). */ \ + if ( diagoffc > 0 ) \ + { \ + jp = diagoffc / NR; \ + j = jp * NR; \ + n = n - j; \ + diagoffc = diagoffc % NR; \ + c_cast = c_cast + (j )*cs_c; \ + b_cast = b_cast + (jp )*ps_b; \ + } \ +\ + /* If there is a zero region below where the diagonal of C intersects + the right edge of the panel, shrink it to prevent "no-op" iterations + from executing. */ \ + if ( -diagoffc + n < m ) \ + { \ + m = -diagoffc + n; \ + } \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + n_iter = n / NR; \ + n_left = n % NR; \ +\ + m_iter = m / MR; \ + m_left = m % MR; \ +\ + if ( n_left ) ++n_iter; \ + if ( m_left ) ++m_iter; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + rstep_a = ps_a; \ +\ + cstep_b = ps_b; \ +\ + rstep_c = rs_c * MR; \ + cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ +\ + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + loop around the microkernel. Here we query the thrinfo_t node for the + 1st (ir) loop around the microkernel. */ \ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ +\ + /* Query the number of threads and thread ids for each loop. */ \ + dim_t jr_nt = bli_thread_n_way( thread ); \ + dim_t jr_tid = bli_thread_work_id( thread ); \ + dim_t ir_nt = bli_thread_n_way( caucus ); \ + dim_t ir_tid = bli_thread_work_id( caucus ); \ +\ + dim_t jr_start, jr_end; \ + dim_t ir_start, ir_end; \ + dim_t jr_inc, ir_inc; \ +\ + /* Note that we partition the 2nd loop into two regions: the triangular + part of C, and the rectangular portion. */ \ + dim_t n_iter_tri; \ + dim_t n_iter_rct; \ +\ + if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) \ + { \ + /* If the entire panel of C does not intersect the diagonal, there is + no triangular region, and therefore we can skip the first set of + loops. */ \ + n_iter_tri = 0; \ + n_iter_rct = n_iter; \ + } \ + else \ + { \ + /* If the panel of C does intersect the diagonal, compute the number of + iterations in the triangular (or trapezoidal) region by dividing NR + into the number of rows in C. A non-zero remainder means we need to + add one additional iteration. That is, we want the triangular region + to contain as few columns of whole microtiles as possible while still + including all microtiles that intersect the diagonal. The number of + iterations in the rectangular region is computed as the remaining + number of iterations in the n dimension. */ \ + n_iter_tri = ( m + diagoffc ) / NR + ( ( m + diagoffc ) % NR ? 1 : 0 ); \ + n_iter_rct = n_iter - n_iter_tri; \ + } \ +\ + /* Use round-robin assignment of micropanels to threads in the 2nd loop + and the default (slab or rr) partitioning in the 1st loop for the + initial triangular region of C (if it exists). */ \ + bli_thread_range_jrir_rr( thread, n_iter_tri, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ + bli_thread_range_jrir ( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( j = jr_start; j < jr_end; j += jr_inc ) \ + { \ + ctype* restrict a1; \ + ctype* restrict c11; \ + ctype* restrict b2; \ +\ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ +\ + n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( i = ir_start; i < ir_end; i += ir_inc ) \ + { \ + ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ +\ + /* Compute the diagonal offset for the submatrix at (i,j). */ \ + diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ +\ + m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ + { \ + a2 = a_cast; \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* If the diagonal intersects the current MR x NR submatrix, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the submatrix is strictly above the diagonal, + we compute and store as we normally would. + And if we're strictly below the diagonal, we do nothing and + continue. */ \ + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + MR, \ + NR, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale C and add the result to only the stored part. */ \ + PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* If there is no rectangular region, then we're done. */ \ + if ( n_iter_rct == 0 ) return; \ +\ + /* Determine the thread range and increment for the 2nd loop of the + remaining rectangular region of C (and also use default partitioning + for the 1st loop). + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ +\ + /* Advance the start and end iteration offsets for the rectangular region + by the number of iterations used for the triangular region. */ \ + jr_start += n_iter_tri; \ + jr_end += n_iter_tri; \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( j = jr_start; j < jr_end; j += jr_inc ) \ + { \ + ctype* restrict a1; \ + ctype* restrict c11; \ + ctype* restrict b2; \ +\ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ +\ + n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( i = ir_start; i < ir_end; i += ir_inc ) \ + { \ + ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ +\ + /* No need to compute the diagonal offset for the rectangular + region. */ \ + /*diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR;*/ \ +\ + m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + a2 = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ + { \ + a2 = a_cast; \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* If the diagonal intersects the current MR x NR submatrix, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the submatrix is strictly above the diagonal, + we compute and store as we normally would. + And if we're strictly below the diagonal, we do nothing and + continue. */ \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( gemmt_u_ker_var2 ) + diff --git a/frame/3/gemmt/other/bli_gemmt_u_ker_var2b.c.before b/frame/3/gemmt/other/bli_gemmt_u_ker_var2b.c.before new file mode 100644 index 0000000000..dbf8f389f1 --- /dev/null +++ b/frame/3/gemmt/other/bli_gemmt_u_ker_var2b.c.before @@ -0,0 +1,415 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmt_fp + +typedef void (*FUNCPTR_T) + ( + doff_t diagoffc, + pack_t schema_a, + pack_t schema_b, + dim_t m, + dim_t n, + dim_t k, + void* alpha, + void* a, inc_t cs_a, inc_t is_a, + dim_t pd_a, inc_t ps_a, + void* b, inc_t rs_b, inc_t is_b, + dim_t pd_b, inc_t ps_b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +static FUNCPTR_T GENARRAY(ftypes,gemmt_u_ker_var2b); + + +void bli_gemmt_u_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + const num_t dt_exec = bli_obj_exec_dt( c ); + + const doff_t diagoffc = bli_obj_diag_offset( c ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const inc_t is_a = bli_obj_imag_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t is_b = bli_obj_imag_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Index into the type combination array to extract the correct + // function pointer. + ftypes[dt_exec] + ( + diagoffc, + schema_a, + schema_b, + m, + n, + k, + ( void* )buf_alpha, + ( void* )buf_a, cs_a, is_a, + pd_a, ps_a, + ( void* )buf_b, rs_b, is_b, + pd_b, ps_b, + ( void* )buf_beta, + buf_c, rs_c, cs_c, + ( cntx_t* )cntx, + rntm, + thread + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffc, \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_ukr_prefers_cols_dt( dt, BLIS_GEMM_VIR_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Safeguard: If the current panel of C is entirely below the diagonal, + it is not stored. So we do nothing. */ \ + if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \ +\ + /* If there is a zero region to the left of where the diagonal of C + intersects the top edge of the panel, adjust the pointer to C and B + and treat this case as if the diagonal offset were zero. + NOTE: It's possible that after this pruning that the diagonal offset + is still positive (though it is guaranteed to be less than NR). */ \ + if ( diagoffc > 0 ) \ + { \ + const dim_t jp = diagoffc / NR; \ + const dim_t j = jp * NR; \ +\ + n = n - j; \ + diagoffc = diagoffc % NR; \ + c_cast = c_cast + (j )*cs_c; \ + b_cast = b_cast + (jp )*ps_b; \ + } \ +\ + /* If there is a zero region below where the diagonal of C intersects + the right edge of the panel, shrink it to prevent "no-op" iterations + from executing. */ \ + if ( -diagoffc + n < m ) \ + { \ + m = -diagoffc + n; \ + } \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); \ + const dim_t n_left = n % NR; \ +\ + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); \ + const dim_t m_left = m % MR; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + const inc_t rstep_a = ps_a; \ +\ + const inc_t cstep_b = ps_b; \ +\ + const inc_t rstep_c = rs_c * MR; \ + const inc_t cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the virtual microkernel address and the params. */ \ + /*bli_auxinfo_set_ukr( gemm_ukr, &aux );*/ \ + /*bli_auxinfo_set_params( params, &aux );*/ \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ +\ + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + loop around the microkernel. Here we query the thrinfo_t node for the + 1st (ir) loop around the microkernel. */ \ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ +\ + /* Query the number of threads and thread ids for each loop. */ \ + dim_t jr_nt = bli_thread_n_way( thread ); \ + dim_t jr_tid = bli_thread_work_id( thread ); \ + dim_t ir_nt = bli_thread_n_way( caucus ); \ + dim_t ir_tid = bli_thread_work_id( caucus ); \ +\ + dim_t jr_start, jr_end; \ + dim_t ir_start, ir_end; \ + dim_t jr_inc, ir_inc; \ +\ + /* Determine the thread range and increment for the 2nd and 1st loops. + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. */ \ + bli_thread_range_weighted_jr( thread, diagoffc, BLIS_UPPER, m, n, NR, \ + FALSE, &jr_start, &jr_end, &jr_inc ); \ + /*bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );*/ \ + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ +\ +/* +printf( "bli_gemmt_u_ker_var2b(): tid %d: m n = %d %d st en in = %3d %3d %3d do %d\n", (int)jr_tid, (int)m, (int)n, (int)jr_start, (int)jr_end, (int)jr_inc, (int)diagoffc ); \ +*/ \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) \ + { \ + ctype* restrict b1 = b_cast + j * cstep_b; \ + ctype* restrict c1 = c_cast + j * cstep_c; \ +\ + /* Compute the diagonal offset for the column of microtiles at (0,j). */ \ + const doff_t diagoffc_j = diagoffc - (doff_t)j*NR; \ + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) \ + ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + ctype* restrict b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) \ + { \ + /* Compute the diagonal offset for the microtile at (i,j). */ \ + const doff_t diagoffc_ij = diagoffc_j + (doff_t)i*MR; \ + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) \ + ? MR : m_left ); \ +\ + /* If the diagonal intersects the current MR x NR submatrix, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the submatrix is strictly above the diagonal, + we compute and store as we normally would. + And if we're strictly below the diagonal, we do nothing and + continue on through the IR loop to consider the next MR x NR + microtile. */ \ + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + ctype* restrict a1 = a_cast + i * rstep_a; \ + ctype* restrict c11 = c1 + i * rstep_c; \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + ctype* restrict a2 \ + = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ + { \ + a2 = bli_gemmt_u_wrap_a_upanel( a_cast, rstep_a, \ + diagoffc_j, MR, NR ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + MR, \ + NR, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale C and add the result to only the stored part. */ \ + PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + ctype* restrict a1 = a_cast + i * rstep_a; \ + ctype* restrict c11 = c1 + i * rstep_c; \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + ctype* restrict a2 \ + = bli_gemmt_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) \ + { \ + a2 = bli_gemmt_u_wrap_a_upanel( a_cast, rstep_a, \ + diagoffc_j, MR, NR ); \ + b2 = bli_gemmt_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( gemmt_u_ker_var2b ) + diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.c b/frame/3/trmm/bli_trmm_ll_ker_var2.c index 3bc4e3c6b4..0c5cde72c1 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.c @@ -37,11 +37,11 @@ void bli_trmm_ll_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { @@ -83,10 +83,10 @@ void bli_trmm_ll_ker_var2 const void* buf_beta = bli_obj_internal_scalar_buffer( c ); // Alias some constants to simpler names. - const dim_t MR = pd_a; - const dim_t NR = pd_b; - const dim_t PACKMR = cs_a; - const dim_t PACKNR = rs_b; + const dim_t MR = pd_a; + const dim_t NR = pd_b; + const dim_t PACKMR = cs_a; + const dim_t PACKNR = rs_b; // Query the context for the micro-kernel address and cast it to its // function pointer type. @@ -140,50 +140,45 @@ void bli_trmm_ll_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_size; - inc_t cstep_c = cs_c * NR * dt_size; + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; - // Save the pack schemas of A and B to the auxinfo_t object. auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) // loop around the microkernel. Here we query the thrinfo_t node for the // 1st (ir) loop around the microkernel. - //thrinfo_t* ir_thread = bli_thrinfo_sub_node( thread ); + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); // Query the number of threads and thread ids for each loop. - thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); - //dim_t ir_nt = bli_thrinfo_n_way( ir_thread ); - //dim_t ir_tid = bli_thrinfo_work_id( ir_thread ); + //const dim_t jr_nt = bli_thrinfo_n_way( thread ); + //const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); - dim_t jr_start, jr_end; - //dim_t ir_start, ir_end; - dim_t jr_inc; + dim_t jr_start, jr_end, jr_inc; // Determine the thread range and increment for the 2nd loop. - // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // NOTE: The definition of bli_thread_range_slrr() will depend on whether // slab or round-robin partitioning was requested at configure-time. // NOTE: Parallelism in the 1st loop is disabled for now. - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - //bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) @@ -191,20 +186,24 @@ void bli_trmm_ll_ker_var2 const char* b1 = b_cast + j * cstep_b; char* c1 = c_cast + j * cstep_c; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; + // Initialize pointers for stepping through the block of A and current + // column of microtiles of C. const char* a1 = a_cast; char* c11 = c1; // Loop over the m dimension (MR rows at a time). for ( dim_t i = 0; i < m_iter; ++i ) { - doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; + const doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); // If the current panel of A intersects the diagonal, scale C // by beta. If it is strictly below the diagonal, scale by one. @@ -215,8 +214,8 @@ void bli_trmm_ll_ker_var2 // Determine the offset to and length of the panel that was // packed so we can index into the corresponding location in // b1. - dim_t off_a1011 = 0; - dim_t k_a1011 = bli_min( diagoffa_i + MR, k ); + const dim_t off_a1011 = 0; + const dim_t k_a1011 = bli_min( diagoffa_i + MR, k ); // Compute the panel stride for the current diagonal- // intersecting micro-panel. @@ -230,13 +229,13 @@ void bli_trmm_ll_ker_var2 const char* b1_i = b1 + off_a1011 * PACKNR * dt_size; // Compute the addresses of the next panels of A and B. - const char* a2 = a1; - if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_slrr( i, m_iter, 0, 1 ) ) { a2 = a_cast; - b2 = b1; - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); + //if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) + // b2 = b_cast; } // Save addresses of next panels of A and B to the auxinfo_t @@ -268,13 +267,13 @@ void bli_trmm_ll_ker_var2 //if ( bli_trmm_my_iter( i, ir_thread ) ) { // Compute the addresses of the next panels of A and B. - const char* a2 = a1; - if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_slrr( i, m_iter, 0, 1 ) ) { a2 = a_cast; - b2 = b1; - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); + //if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) + // b2 = b_cast; } // Save addresses of next panels of A and B to the auxinfo_t @@ -306,6 +305,6 @@ void bli_trmm_ll_ker_var2 } } -//PASTEMAC(ch,fprintm)( stdout, "trmm_ll_ker_var2: a1", MR, k_a1011, a1, 1, MR, "%4.1f", "" ); -//PASTEMAC(ch,fprintm)( stdout, "trmm_ll_ker_var2: b1", k_a1011, NR, b1_i, NR, 1, "%4.1f", "" ); +//PASTEMAC(ch,printm)( "trmm_ll_ker_var2: a1", MR, k_a1011, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,printm)( "trmm_ll_ker_var2: b1", k_a1011, NR, b1_i, NR, 1, "%4.1f", "" ); diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2b.c b/frame/3/trmm/bli_trmm_ll_ker_var2b.c new file mode 100644 index 0000000000..bb6de00f5b --- /dev/null +++ b/frame/3/trmm/bli_trmm_ll_ker_var2b.c @@ -0,0 +1,365 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_ll_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt = bli_obj_exec_dt( c ); + const dim_t dt_size = bli_dt_size( dt ); + + doff_t diagoffa = bli_obj_diag_offset( a ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + const dim_t PACKMR = cs_a; + const dim_t PACKNR = rs_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); + + const void* one = bli_obj_buffer_for_const( dt, &BLIS_ONE ); + const char* a_cast = buf_a; + const char* b_cast = buf_b; + char* c_cast = buf_c; + const char* alpha_cast = buf_alpha; + const char* beta_cast = buf_beta; + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + // Safety trap: Certain indexing within this macro-kernel does not + // work as intended if both MR and NR are odd. + if ( ( bli_is_odd( PACKMR ) && bli_is_odd( NR ) ) || + ( bli_is_odd( PACKNR ) && bli_is_odd( MR ) ) ) bli_abort(); + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Safeguard: If the current block of A is entirely above the diagonal, + // it is implicitly zero. So we do nothing. + if ( bli_is_strictly_above_diag_n( diagoffa, m, k ) ) return; + + // If there is a zero region above where the diagonal of A intersects the + // left edge of the block, adjust the pointer to C and treat this case as + // if the diagonal offset were zero. This skips over the region that was + // not packed. (Note we assume the diagonal offset is a multiple of MR; + // this assumption will hold as long as the cache blocksizes KC nd MC are + // each a multiple of MR.) + if ( diagoffa < 0 ) + { + m += diagoffa; + c_cast -= diagoffa * rs_c * dt_size; + diagoffa = 0; + } + + // Compute number of primary and leftover components of the m and n + // dimensions. + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; + + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; + + // Computing the number of MR x MR tiles in the k dimension is needed + // when computing the thread ranges below. + const dim_t k_iter = k / MR + ( k % MR ? 1 : 0 ); + + // Determine some increments used to step through A, B, and C. + const inc_t rstep_a = ps_a * dt_size; + + const inc_t cstep_b = ps_b * dt_size; + + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + // Query the number of threads and thread ids for the JR loop. + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); + + dim_t jr_st, ir_st; + const dim_t n_ut_for_me + = + bli_thread_range_tlb_trmm_ll( jr_nt, jr_tid, diagoffa, m_iter, n_iter, k_iter, + MR, NR, &jr_st, &ir_st ); + +#if 0 + printf( "tid: %ld m,n,k_iter: %ld %ld %ld\n", tid, m_iter, n_iter, k_iter ); + printf( "tid: %ld trmm_ll_tlb begins at: %ld %ld (n_ut: %ld)\n", + tid, jr_st, ir_st, n_ut_for_me ); +#endif + + // It's possible that there are so few microtiles relative to the number + // of threads that one or more threads gets no work. If that happens, those + // threads can return early. + if ( n_ut_for_me == 0 ) return; + + // Start the jr/ir loops with the current thread's microtile offsets computed + // by bli_thread_range_tlb_trmm_ll(). + dim_t i = ir_st; + dim_t j = jr_st; + + // Initialize a counter to track the number of microtiles computed by the + // current thread. + dim_t ut = 0; + + const char* a1 = a_cast; + + // Get pointers into position by stepping through to the ith micropanel of + // A and ith microtile of C (within the appropriate column of microtiles). + for ( dim_t ii = 0; ii < ir_st; ++ii ) + { + const doff_t diagoffa_ii = diagoffa + ( doff_t )ii*MR; + + if ( bli_intersects_diag_n( diagoffa_ii, MR, k ) ) + { + // Determine the length of the panel that was packed. + const dim_t k_a1011 = bli_min( diagoffa_ii + MR, k ); + + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_a_cur = k_a1011 * PACKMR; + ps_a_cur += ( bli_is_odd( ps_a_cur ) ? 1 : 0 ); + ps_a_cur *= dt_size; + + a1 += ps_a_cur; + } + else if ( bli_is_strictly_below_diag_n( diagoffa_ii, MR, k ) ) + { + a1 += rstep_a; + } + } + + // Loop over the n dimension (NR columns at a time). + for ( ; true; ++j ) + { + const char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + bli_auxinfo_set_next_b( b2, &aux ); + + // Loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + char* c11 = c1 + i * rstep_c; + + const doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; + + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + // If the current panel of A intersects the diagonal, scale C + // by beta. If it is strictly below the diagonal, scale by one. + // This allows the current macro-kernel to work for both trmm + // and trmm3. + if ( bli_intersects_diag_n( diagoffa_i, MR, k ) ) + { + // Determine the offset to and length of the panel that was + // packed so we can index into the corresponding location in B. + const dim_t off_a1011 = 0; + const dim_t k_a1011 = bli_min( diagoffa_i + MR, k ); + + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_a_cur = k_a1011 * PACKMR; + ps_a_cur += ( bli_is_odd( ps_a_cur ) ? 1 : 0 ); + ps_a_cur *= dt_size; + + const char* b1_i = b1 + off_a1011 * PACKNR * dt_size; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, ps_a_cur, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k_a1011, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1_i, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; if ( ut == n_ut_for_me ) return; + + a1 += ps_a_cur; + } + else if ( bli_is_strictly_below_diag_n( diagoffa_i, MR, k ) ) + { + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )one, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; if ( ut == n_ut_for_me ) return; + + a1 += rstep_a; + } + } + + // Upon reaching the end of the column of microtiles, reset the ir + // loop index so that we're ready to start the next pass through the + // m dimension (i.e., the next jr loop iteration). + i = 0; + + // Reset the a1 pointer to the beginning of the packed matrix A. + a1 = a_cast; + } +} + +//PASTEMAC(ch,printm)( "trmm_ll_ker_var2b: a1", MR, k_a1011, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,printm)( "trmm_ll_ker_var2b: b1", k_a1011, NR, b1_i, NR, 1, "%4.1f", "" ); + +//printf( "tid: %ld intersects diag. j,i: %ld %ld (ut: %ld)\n", tid, j, i, ut ); +//printf( "tid: %ld strictbelow diag j,i: %ld %ld (ut: %ld)\n", tid, j, i, ut ); + +//printf( "tid: %ld incrementing by ps_a_cur: %ld (k_a1011: %ld)\n", +// tid, ps_a_cur, k_a1011 ); +//printf( "tid: %ld incrementing by rstep_a: %ld (k : %ld)\n", +// tid, rstep_a, k ); + diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.c b/frame/3/trmm/bli_trmm_lu_ker_var2.c index 265e21a66a..039bcc2926 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.c @@ -37,11 +37,11 @@ void bli_trmm_lu_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { @@ -83,10 +83,10 @@ void bli_trmm_lu_ker_var2 const void* buf_beta = bli_obj_internal_scalar_buffer( c ); // Alias some constants to simpler names. - const dim_t MR = pd_a; - const dim_t NR = pd_b; - const dim_t PACKMR = cs_a; - const dim_t PACKNR = rs_b; + const dim_t MR = pd_a; + const dim_t NR = pd_b; + const dim_t PACKMR = cs_a; + const dim_t PACKNR = rs_b; // Query the context for the micro-kernel address and cast it to its // function pointer type. @@ -147,50 +147,45 @@ void bli_trmm_lu_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_size; - inc_t cstep_c = cs_c * NR * dt_size; + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; - // Save the pack schemas of A and B to the auxinfo_t object. auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) // loop around the microkernel. Here we query the thrinfo_t node for the // 1st (ir) loop around the microkernel. - //thrinfo_t* ir_thread = bli_thrinfo_sub_node( thread ); + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); // Query the number of threads and thread ids for each loop. - thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); - //dim_t ir_nt = bli_thrinfo_n_way( ir_thread ); - //dim_t ir_tid = bli_thrinfo_work_id( ir_thread ); + //const dim_t jr_nt = bli_thrinfo_n_way( thread ); + //const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); - dim_t jr_start, jr_end; - //dim_t ir_start, ir_end; - dim_t jr_inc; + dim_t jr_start, jr_end, jr_inc; // Determine the thread range and increment for the 2nd loop. - // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // NOTE: The definition of bli_thread_range_slrr() will depend on whether // slab or round-robin partitioning was requested at configure-time. // NOTE: Parallelism in the 1st loop is disabled for now. - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - //bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) @@ -198,20 +193,24 @@ void bli_trmm_lu_ker_var2 const char* b1 = b_cast + j * cstep_b; char* c1 = c_cast + j * cstep_c; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; + // Initialize pointers for stepping through the block of A and current + // column of microtiles of C. const char* a1 = a_cast; char* c11 = c1; // Loop over the m dimension (MR rows at a time). for ( dim_t i = 0; i < m_iter; ++i ) { - doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; + const doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); // If the current panel of A intersects the diagonal, scale C // by beta. If it is strictly above the diagonal, scale by one. @@ -222,8 +221,8 @@ void bli_trmm_lu_ker_var2 // Determine the offset to and length of the panel that was // packed so we can index into the corresponding location in // b1. - dim_t off_a1112 = diagoffa_i; - dim_t k_a1112 = k - off_a1112; + const dim_t off_a1112 = diagoffa_i; + const dim_t k_a1112 = k - off_a1112; // Compute the panel stride for the current diagonal- // intersecting micro-panel. @@ -237,13 +236,13 @@ void bli_trmm_lu_ker_var2 const char* b1_i = b1 + off_a1112 * PACKNR * dt_size; // Compute the addresses of the next panels of A and B. - const char* a2 = a1; - if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_slrr( i, m_iter, 0, 1 ) ) { a2 = a_cast; - b2 = b1; - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); + //if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) + // b2 = b_cast; } // Save addresses of next panels of A and B to the auxinfo_t @@ -275,13 +274,13 @@ void bli_trmm_lu_ker_var2 //if ( bli_trmm_my_iter( i, ir_thread ) ) { // Compute the addresses of the next panels of A and B. - const char* a2 = a1; - if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_slrr( i, m_iter, 0, 1 ) ) { a2 = a_cast; - b2 = b1; - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); + //if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) + // b2 = b_cast; } // Save addresses of next panels of A and B to the auxinfo_t @@ -313,6 +312,6 @@ void bli_trmm_lu_ker_var2 } } -//PASTEMAC(ch,fprintm)( stdout, "trmm_lu_ker_var2: a1", MR, k_a1112, a1, 1, MR, "%4.1f", "" ); -//PASTEMAC(ch,fprintm)( stdout, "trmm_lu_ker_var2: b1", k_a1112, NR, b1_i, NR, 1, "%4.1f", "" ); +//PASTEMAC(ch,printm)( "trmm_lu_ker_var2: a1", MR, k_a1112, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,printm)( "trmm_lu_ker_var2: b1", k_a1112, NR, b1_i, NR, 1, "%4.1f", "" ); diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2b.c b/frame/3/trmm/bli_trmm_lu_ker_var2b.c new file mode 100644 index 0000000000..39640ad6bf --- /dev/null +++ b/frame/3/trmm/bli_trmm_lu_ker_var2b.c @@ -0,0 +1,366 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_lu_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt = bli_obj_exec_dt( c ); + const dim_t dt_size = bli_dt_size( dt ); + + doff_t diagoffa = bli_obj_diag_offset( a ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + const dim_t PACKMR = cs_a; + const dim_t PACKNR = rs_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); + + const void* one = bli_obj_buffer_for_const( dt, &BLIS_ONE ); + const char* a_cast = buf_a; + const char* b_cast = buf_b; + char* c_cast = buf_c; + const char* alpha_cast = buf_alpha; + const char* beta_cast = buf_beta; + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + // Safety trap: Certain indexing within this macro-kernel does not + // work as intended if both MR and NR are odd. + if ( ( bli_is_odd( PACKMR ) && bli_is_odd( NR ) ) || + ( bli_is_odd( PACKNR ) && bli_is_odd( MR ) ) ) bli_abort(); + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Safeguard: If the current block of A is entirely below the diagonal, + // it is implicitly zero. So we do nothing. + if ( bli_is_strictly_below_diag_n( diagoffa, m, k ) ) return; + + // If there is a zero region to the left of where the diagonal of A + // intersects the top edge of the block, adjust the pointer to B and + // treat this case as if the diagonal offset were zero. Note that we + // don't need to adjust the pointer to A since packm would have simply + // skipped over the region that was not stored. (Note we assume the + // diagonal offset is a multiple of MR; this assumption will hold as + // long as the cache blocksizes KC nd MC are each a multiple of MR.) + if ( diagoffa > 0 ) + { + k -= diagoffa; + b_cast += diagoffa * PACKNR * dt_size; + diagoffa = 0; + } + + // If there is a zero region below where the diagonal of A intersects the + // right side of the block, shrink it to prevent "no-op" iterations from + // executing. + if ( -diagoffa + k < m ) + { + m = -diagoffa + k; + } + + // Compute number of primary and leftover components of the m and n + // dimensions. + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; + + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; + + // Computing the number of MR x MR tiles in the k dimension is needed + // when computing the thread ranges below. + const dim_t k_iter = k / MR + ( k % MR ? 1 : 0 ); + + // Determine some increments used to step through A, B, and C. + const inc_t rstep_a = ps_a * dt_size; + + const inc_t cstep_b = ps_b * dt_size; + + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + // Query the number of threads and thread ids for each loop. + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); + + dim_t jr_st, ir_st; + const dim_t n_ut_for_me + = + bli_thread_range_tlb_trmm_lu( jr_nt, jr_tid, diagoffa, m_iter, n_iter, k_iter, + MR, NR, &jr_st, &ir_st ); + + // It's possible that there are so few microtiles relative to the number + // of threads that one or more threads gets no work. If that happens, those + // threads can return early. + if ( n_ut_for_me == 0 ) return; + + // Start the jr/ir loops with the current thread's microtile offsets computed + // by bli_thread_range_tlb_trmm_ll(). + dim_t i = ir_st; + dim_t j = jr_st; + + // Initialize a counter to track the number of microtiles computed by the + // current thread. + dim_t ut = 0; + + const char* a1 = a_cast; + + // Get pointers into position by stepping through to the ith micropanel of + // A and ith microtile of C (within the appropriate column of microtiles). + for ( dim_t ii = 0; ii < ir_st; ++ii ) + { + const doff_t diagoffa_ii = diagoffa + ( doff_t )ii*MR; + + if ( bli_intersects_diag_n( diagoffa_ii, MR, k ) ) + { + // Determine the length of the panel that was packed. + const dim_t k_a1112 = k - diagoffa_ii; + + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_a_cur = k_a1112 * PACKMR; + ps_a_cur += ( bli_is_odd( ps_a_cur ) ? 1 : 0 ); + ps_a_cur *= dt_size; + + a1 += ps_a_cur; + } + else if ( bli_is_strictly_above_diag_n( diagoffa_ii, MR, k ) ) + { + a1 += rstep_a; + } + } + + // Loop over the n dimension (NR columns at a time). + for ( ; true; ++j ) + { + const char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + bli_auxinfo_set_next_b( b2, &aux ); + + // Loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + char* c11 = c1 + i * rstep_c; + + const doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; + + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + // If the current panel of A intersects the diagonal, scale C + // by beta. If it is strictly above the diagonal, scale by one. + // This allows the current macro-kernel to work for both trmm + // and trmm3. + if ( bli_intersects_diag_n( diagoffa_i, MR, k ) ) + { + // Determine the offset to and length of the panel that was + // packed so we can index into the corresponding location in B. + const dim_t off_a1112 = diagoffa_i; + const dim_t k_a1112 = k - off_a1112; + + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_a_cur = k_a1112 * PACKMR; + ps_a_cur += ( bli_is_odd( ps_a_cur ) ? 1 : 0 ); + ps_a_cur *= dt_size; + + const char* b1_i = b1 + off_a1112 * PACKNR * dt_size; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, ps_a_cur, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k_a1112, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1_i, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; if ( ut == n_ut_for_me ) return; + + a1 += ps_a_cur; + } + else if ( bli_is_strictly_above_diag_n( diagoffa_i, MR, k ) ) + { + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )one, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; if ( ut == n_ut_for_me ) return; + + a1 += rstep_a; + } + } + + // Upon reaching the end of the column of microtiles, reset the ir + // loop index so that we're ready to start the next pass through the + // m dimension (i.e., the next jr loop iteration). + i = 0; + + // Reset the a1 pointer to the beginning of the packed matrix A. + a1 = a_cast; + } +} + +//PASTEMAC(ch,printm)( "trmm_lu_ker_var2: a1", MR, k_a1112, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,printm)( "trmm_lu_ker_var2: b1", k_a1112, NR, b1_i, NR, 1, "%4.1f", "" ); + +#if 0 + printf( "tid: %ld m,n,k_iter: %ld %ld %ld\n", tid, m_iter, n_iter, k_iter ); + printf( "tid: %ld trmm_lu_tlb begins at: %ld %ld (n_ut: %ld)\n", + tid, jr_st, ir_st, n_ut_for_me ); +#endif + diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.c b/frame/3/trmm/bli_trmm_rl_ker_var2.c index 785f2cf5fd..f8d0fc6c85 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.c @@ -37,11 +37,11 @@ void bli_trmm_rl_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { @@ -147,39 +147,40 @@ void bli_trmm_rl_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_size; - inc_t cstep_c = cs_c * NR * dt_size; + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; - // Save the pack schemas of A and B to the auxinfo_t object. auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); - dim_t ir_nt = bli_thrinfo_n_way( caucus ); - dim_t ir_tid = bli_thrinfo_work_id( caucus ); + // Query the number of threads and thread ids for each loop. + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + const dim_t ir_tid = bli_thrinfo_work_id( caucus ); - dim_t jr_start, jr_end; - dim_t ir_start, ir_end; - dim_t jr_inc, ir_inc; + dim_t jr_start, jr_end, jr_inc; + dim_t ir_start, ir_end, ir_inc; // Note that we partition the 2nd loop into two regions: the rectangular // part of B, and the triangular portion. @@ -207,11 +208,11 @@ void bli_trmm_rl_ker_var2 // Determine the thread range and increment for the 2nd and 1st loops for // the initial rectangular region of B (if it exists). - // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // NOTE: The definition of bli_thread_range_slrr() will depend on whether // slab or round-robin partitioning was requested at configure-time. // NOTE: Parallelism in the 1st loop is disabled for now. - bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + bli_thread_range_slrr( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_slrr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) @@ -219,7 +220,7 @@ void bli_trmm_rl_ker_var2 const char* b1 = b_cast + j * cstep_b; char* c1 = c_cast + j * cstep_c; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -231,15 +232,15 @@ void bli_trmm_rl_ker_var2 const char* a1 = a_cast + i * rstep_a; char* c11 = c1 + i * rstep_c; - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); // Compute the addresses of the next panels of A and B. const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, ir_inc ); - if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) + if ( bli_is_last_iter_slrr( i, m_iter, ir_tid, ir_nt ) ) { a2 = a_cast; b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) + if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) b2 = b_cast; } @@ -271,7 +272,7 @@ void bli_trmm_rl_ker_var2 // Use round-robin assignment of micropanels to threads in the 2nd and // 1st loops for the remaining triangular region of B (if it exists). - // NOTE: We don't need to call bli_thread_range_jrir_rr() here since we + // NOTE: We don't need to call bli_thread_range_rr() here since we // employ a hack that calls for each thread to execute every iteration // of the jr and ir loops but skip all but the pointer increment for // iterations that are not assigned to it. @@ -285,18 +286,18 @@ void bli_trmm_rl_ker_var2 // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < n_iter; ++j ) { - doff_t diagoffb_j = diagoffb - ( doff_t )j*NR; + const doff_t diagoffb_j = diagoffb - ( doff_t )j*NR; // Determine the offset to the beginning of the panel that // was packed so we can index into the corresponding location // in A. Then compute the length of that panel. - dim_t off_b1121 = bli_max( -diagoffb_j, 0 ); - dim_t k_b1121 = k - off_b1121; + const dim_t off_b1121 = bli_max( -diagoffb_j, 0 ); + const dim_t k_b1121 = k - off_b1121; const char* a1 = a_cast; char* c11 = c1; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -319,7 +320,7 @@ void bli_trmm_rl_ker_var2 { if ( bli_trmm_my_iter_rr( i, caucus ) ) { - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); const char* a1_i = a1 + off_b1121 * PACKMR * dt_size; diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2b.c b/frame/3/trmm/bli_trmm_rl_ker_var2b.c new file mode 100644 index 0000000000..7f2757c3af --- /dev/null +++ b/frame/3/trmm/bli_trmm_rl_ker_var2b.c @@ -0,0 +1,392 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_rl_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt = bli_obj_exec_dt( c ); + const dim_t dt_size = bli_dt_size( dt ); + + doff_t diagoffb = bli_obj_diag_offset( b ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + const dim_t PACKMR = cs_a; + const dim_t PACKNR = rs_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); + + const void* one = bli_obj_buffer_for_const( dt, &BLIS_ONE ); + const char* a_cast = buf_a; + const char* b_cast = buf_b; + char* c_cast = buf_c; + const char* alpha_cast = buf_alpha; + const char* beta_cast = buf_beta; + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + // Safety trap: Certain indexing within this macro-kernel does not + // work as intended if both MR and NR are odd. + if ( ( bli_is_odd( PACKMR ) && bli_is_odd( NR ) ) || + ( bli_is_odd( PACKNR ) && bli_is_odd( MR ) ) ) bli_abort(); + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Safeguard: If the current panel of B is entirely above the diagonal, + // it is implicitly zero. So we do nothing. + if ( bli_is_strictly_above_diag_n( diagoffb, k, n ) ) return; + + // If there is a zero region above where the diagonal of B intersects + // the left edge of the panel, adjust the pointer to A and treat this + // case as if the diagonal offset were zero. Note that we don't need to + // adjust the pointer to B since packm would have simply skipped over + // the region that was not stored. (Note we assume the diagonal offset + // is a multiple of NR; this assumption will hold as long as the cache + // blocksizes KC and NC are each a multiple of NR.) + if ( diagoffb < 0 ) + { + k += diagoffb; + a_cast -= diagoffb * PACKMR * dt_size; + diagoffb = 0; + } + + // If there is a zero region to the right of where the diagonal + // of B intersects the bottom of the panel, shrink it to prevent + // "no-op" iterations from executing. + if ( diagoffb + k < n ) + { + n = diagoffb + k; + } + + // Compute number of primary and leftover components of the m and n + // dimensions. + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; + + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; + + // Computing the number of NR x NR tiles in the k dimension is needed + // when computing the thread ranges below. + const dim_t k_iter = k / NR + ( k % NR ? 1 : 0 ); + + // Determine some increments used to step through A, B, and C. + const inc_t rstep_a = ps_a * dt_size; + + const inc_t cstep_b = ps_b * dt_size; + + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel while the 'caucus' points to the thrinfo_t + // node for the 1st loop (ir). + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + // Query the number of threads and thread ids for each loop. +#if 0 +{ + const dim_t jr_nt = 17; + const dim_t jr_tid = jr_nt - 1; + + const doff_t m_iter = 10; + const doff_t k_iter = 10; + const doff_t n_iter = 20; + + diagoffb = 30 * NR; +#else + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); +#endif + dim_t jr_st, ir_st; + const dim_t n_ut_for_me + = + bli_thread_range_tlb_trmm_rl( jr_nt, jr_tid, diagoffb, m_iter, n_iter, k_iter, + MR, NR, &jr_st, &ir_st ); + +#if 0 + printf( "tid %ld: final range: jr_st, ir_st: %ld %ld (n_ut_for_me: %ld)\n", + jr_tid, jr_st, ir_st, n_ut_for_me ); + return; +} +const dim_t n_ut_for_me = -1; dim_t jr_st, ir_st; +#endif + + // It's possible that there are so few microtiles relative to the number + // of threads that one or more threads gets no work. If that happens, those + // threads can return early. + if ( n_ut_for_me == 0 ) return; + + // Start the jr/ir loops with the current thread's microtile offsets computed + // by bli_thread_range_tlb_trmm_r(). + dim_t i = ir_st; + dim_t j = jr_st; + + // Initialize a counter to track the number of microtiles computed by the + // current thread. + dim_t ut = 0; + + const char* b1 = b_cast; + + // Get pointers into position by stepping through to the jth micropanel of + // B and jth microtile of C (within the appropriate row of microtiles). + for ( dim_t jj = 0; jj < jr_st; ++jj ) + { + const doff_t diagoffb_jj = diagoffb - ( doff_t )jj*NR; + + if ( bli_intersects_diag_n( diagoffb_jj, k, NR ) ) + { + // Determine the length of the panel that was packed. + const dim_t off_b1121 = bli_max( -diagoffb_jj, 0 ); + const dim_t k_b1121 = k - off_b1121; + + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_b_cur = k_b1121 * PACKNR; + ps_b_cur += ( bli_is_odd( ps_b_cur ) ? 1 : 0 ); + ps_b_cur *= dt_size; + + b1 += ps_b_cur; + } + else if ( bli_is_strictly_below_diag_n( diagoffb_jj, k, NR ) ) + { + b1 += cstep_b; + } + } + + // Loop over the n dimension (NR columns at a time). + for ( ; true; ++j ) + { + char* c1 = c_cast + j * cstep_c; + + const doff_t diagoffb_j = diagoffb - ( doff_t )j*NR; + + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); + + // Determine the offset to and length of the panel that was packed + // so we can index into the corresponding location in A. + const dim_t off_b1121 = bli_max( -diagoffb_j, 0 ); + const dim_t k_b1121 = k - off_b1121; + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + bli_auxinfo_set_next_b( b2, &aux ); + + // If the current panel of B intersects the diagonal, scale C + // by beta. If it is strictly below the diagonal, scale by one. + // This allows the current macro-kernel to work for both trmm + // and trmm3. + if ( bli_intersects_diag_n( diagoffb_j, k, NR ) ) + { + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_b_cur = k_b1121 * PACKNR; + ps_b_cur += ( bli_is_odd( ps_b_cur ) ? 1 : 0 ); + ps_b_cur *= dt_size; + + // Loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + const char* a1_i = a1 + off_b1121 * PACKMR * dt_size; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, ps_b_cur, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k_b1121, + ( void* )alpha_cast, + ( void* )a1_i, + ( void* )b1, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; if ( ut == n_ut_for_me ) return; + } + + // Upon reaching the end of the column of microtiles, reset the ir + // loop index so that we're ready to start the next pass through the + // m dimension (i.e., the next jr loop iteration). + i = 0; + + b1 += ps_b_cur; + } + else if ( bli_is_strictly_below_diag_n( diagoffb_j, k, NR ) ) + { + // Loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )one, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; if ( ut == n_ut_for_me ) return; + } + + // Upon reaching the end of the column of microtiles, reset the ir + // loop index so that we're ready to start the next pass through the + // m dimension (i.e., the next jr loop iteration). + i = 0; + + b1 += cstep_b; + } + } +} + +//PASTEMAC(ch,fprintm)( stdout, "trmm_rl_ker_var2: a1", MR, k_b1121, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,fprintm)( stdout, "trmm_rl_ker_var2: b1", k_b1121, NR, b1_i, NR, 1, "%4.1f", "" ); + diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.c b/frame/3/trmm/bli_trmm_ru_ker_var2.c index ca27caef10..a031b67947 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.c @@ -37,11 +37,11 @@ void bli_trmm_ru_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { @@ -148,25 +148,23 @@ void bli_trmm_ru_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_size; - inc_t cstep_c = cs_c * NR * dt_size; + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; - // Save the pack schemas of A and B to the auxinfo_t object. auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); @@ -177,14 +175,13 @@ void bli_trmm_ru_ker_var2 thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); // Query the number of threads and thread ids for each loop. - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); - dim_t ir_nt = bli_thrinfo_n_way( caucus ); - dim_t ir_tid = bli_thrinfo_work_id( caucus ); + //const dim_t jr_nt = bli_thrinfo_n_way( thread ); + //const dim_t jr_tid = bli_thrinfo_work_id( thread ); + const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + const dim_t ir_tid = bli_thrinfo_work_id( caucus ); - dim_t jr_start, jr_end; - dim_t ir_start, ir_end; - dim_t jr_inc, ir_inc; + dim_t jr_start, jr_end, jr_inc; + dim_t ir_start, ir_end, ir_inc; // Note that we partition the 2nd loop into two regions: the triangular // part of C, and the rectangular portion. @@ -212,7 +209,7 @@ void bli_trmm_ru_ker_var2 // Use round-robin assignment of micropanels to threads in the 2nd and // 1st loops for the initial triangular region of B (if it exists). - // NOTE: We don't need to call bli_thread_range_jrir_rr() here since we + // NOTE: We don't need to call bli_thread_range_rr() here since we // employ a hack that calls for each thread to execute every iteration // of the jr and ir loops but skip all but the pointer increment for // iterations that are not assigned to it. @@ -223,17 +220,18 @@ void bli_trmm_ru_ker_var2 // Loop over the n dimension (NR columns at a time). for ( dim_t j = 0; j < n_iter_tri; ++j ) { - doff_t diagoffb_j = diagoffb - ( doff_t )j*NR; + const doff_t diagoffb_j = diagoffb - ( doff_t )j*NR; // Determine the offset to and length of the panel that was packed // so we can index into the corresponding location in A. - dim_t off_b0111 = 0; - dim_t k_b0111 = bli_min( k, -diagoffb_j + NR ); + const dim_t off_b0111 = 0; + const dim_t k_b0111 = bli_min( k, -diagoffb_j + NR ); const char* a1 = a_cast; char* c11 = c1; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -256,7 +254,8 @@ void bli_trmm_ru_ker_var2 { if ( bli_trmm_my_iter_rr( i, caucus ) ) { - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); const char* a1_i = a1 + off_b0111 * PACKMR * dt_size; @@ -266,8 +265,6 @@ void bli_trmm_ru_ker_var2 { a2 = a_cast; b2 = b1; - if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; } // Save addresses of next panels of A and B to the auxinfo_t @@ -307,11 +304,11 @@ void bli_trmm_ru_ker_var2 // Determine the thread range and increment for the 2nd and 1st loops for // the remaining rectangular region of B. - // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // NOTE: The definition of bli_thread_range_slrr() will depend on whether // slab or round-robin partitioning was requested at configure-time. // NOTE: Parallelism in the 1st loop is disabled for now. - bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + bli_thread_range_slrr( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_slrr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); // Advance the start and end iteration offsets for the rectangular region // by the number of iterations used for the triangular region. @@ -332,7 +329,8 @@ void bli_trmm_ru_ker_var2 b1 = b_cast + (j-jb0) * cstep_b; c1 = c_cast + j * cstep_c; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -348,16 +346,15 @@ void bli_trmm_ru_ker_var2 const char* a1 = a_cast + i * rstep_a; char* c11 = c1 + i * rstep_c; - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); // Compute the addresses of the next panels of A and B. const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, ir_inc ); - if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) + if ( bli_is_last_iter_slrr( i, m_iter, ir_tid, ir_nt ) ) { a2 = a_cast; b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) - b2 = b_cast; } // Save addresses of next panels of A and B to the auxinfo_t diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2b.c b/frame/3/trmm/bli_trmm_ru_ker_var2b.c new file mode 100644 index 0000000000..8aae2386aa --- /dev/null +++ b/frame/3/trmm/bli_trmm_ru_ker_var2b.c @@ -0,0 +1,390 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_ru_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt = bli_obj_exec_dt( c ); + const dim_t dt_size = bli_dt_size( dt ); + + doff_t diagoffb = bli_obj_diag_offset( b ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + const dim_t PACKMR = cs_a; + const dim_t PACKNR = rs_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); + + const void* one = bli_obj_buffer_for_const( dt, &BLIS_ONE ); + const char* a_cast = buf_a; + const char* b_cast = buf_b; + char* c_cast = buf_c; + const char* alpha_cast = buf_alpha; + const char* beta_cast = buf_beta; + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + // Safety trap: Certain indexing within this macro-kernel does not + // work as intended if both MR and NR are odd. + if ( ( bli_is_odd( PACKMR ) && bli_is_odd( NR ) ) || + ( bli_is_odd( PACKNR ) && bli_is_odd( MR ) ) ) bli_abort(); + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Safeguard: If the current panel of B is entirely below its diagonal, + // it is implicitly zero. So we do nothing. + if ( bli_is_strictly_below_diag_n( diagoffb, k, n ) ) return; + + // If there is a zero region to the left of where the diagonal of B + // intersects the top edge of the panel, adjust the pointer to C and + // treat this case as if the diagonal offset were zero. This skips over + // the region that was not packed. (Note we assume the diagonal offset + // is a multiple of NR; this assumption will hold as long as the cache + // blocksizes KC and NC are each a multiple of NR.) + if ( diagoffb > 0 ) + { + n -= diagoffb; + c_cast += diagoffb * cs_c * dt_size; + diagoffb = 0; + } + + // If there is a zero region below where the diagonal of B intersects the + // right side of the block, shrink it to prevent "no-op" iterations from + // executing. + if ( -diagoffb + n < k ) + { + k = -diagoffb + n; + } + + // Compute number of primary and leftover components of the m and n + // dimensions. + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; + + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; + + // Computing the number of NR x NR tiles in the k dimension is needed + // when computing the thread ranges below. + const dim_t k_iter = k / NR + ( k % NR ? 1 : 0 ); + + // Determine some increments used to step through A, B, and C. + const inc_t rstep_a = ps_a * dt_size; + + const inc_t cstep_b = ps_b * dt_size; + + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + // Query the number of threads and thread ids for each loop. +#if 0 +{ + const dim_t jr_nt = 1; + const dim_t jr_tid = 0; //jr_nt - 1; + + const doff_t m_iter = 10; + const doff_t k_iter = 10; + const doff_t n_iter = 20; + + diagoffb = 0 * NR; +#else + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); +#endif + dim_t jr_st, ir_st; + const dim_t n_ut_for_me + = + bli_thread_range_tlb_trmm_ru( jr_nt, jr_tid, diagoffb, m_iter, n_iter, k_iter, + MR, NR, &jr_st, &ir_st ); + +#if 0 + printf( "tid %ld: final range: jr_st, ir_st: %ld %ld (n_ut_for_me: %ld)\n", + jr_tid, jr_st, ir_st, n_ut_for_me ); + return; +} +const dim_t n_ut_for_me = -1; dim_t jr_st, ir_st; +#endif + + // It's possible that there are so few microtiles relative to the number + // of threads that one or more threads gets no work. If that happens, those + // threads can return early. + if ( n_ut_for_me == 0 ) return; + + // Start the jr/ir loops with the current thread's microtile offsets computed + // by bli_thread_range_tlb_trmm_r(). + dim_t i = ir_st; + dim_t j = jr_st; + + // Initialize a counter to track the number of microtiles computed by the + // current thread. + dim_t ut = 0; + + const char* b1 = b_cast; + + // Get pointers into position by stepping through to the jth micropanel of + // B and jth microtile of C (within the appropriate row of microtiles). + for ( dim_t jj = 0; jj < jr_st; ++jj ) + { + const doff_t diagoffb_jj = diagoffb - ( doff_t )jj*NR; + + if ( bli_intersects_diag_n( diagoffb_jj, k, NR ) ) + { + // Determine the length of the panel that was packed. + const dim_t k_b0111 = bli_min( k, -diagoffb_jj + NR ); + + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_b_cur = k_b0111 * PACKNR; + ps_b_cur += ( bli_is_odd( ps_b_cur ) ? 1 : 0 ); + ps_b_cur *= dt_size; + + b1 += ps_b_cur; + } + else if ( bli_is_strictly_above_diag_n( diagoffb_jj, k, NR ) ) + { + b1 += cstep_b; + } + } + + // Loop over the n dimension (NR columns at a time). + for ( ; true; ++j ) + { + char* c1 = c_cast + j * cstep_c; + + const doff_t diagoffb_j = diagoffb - ( doff_t )j*NR; + + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); + + // Determine the offset to and length of the panel that was packed + // so we can index into the corresponding location in A. + const dim_t off_b0111 = 0; + const dim_t k_b0111 = bli_min( k, -diagoffb_j + NR ); + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + bli_auxinfo_set_next_b( b2, &aux ); + + // If the current panel of B intersects the diagonal, scale C + // by beta. If it is strictly above the diagonal, scale by one. + // This allows the current macro-kernel to work for both trmm + // and trmm3. + if ( bli_intersects_diag_n( diagoffb_j, k, NR ) ) + { + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_b_cur = k_b0111 * PACKNR; + ps_b_cur += ( bli_is_odd( ps_b_cur ) ? 1 : 0 ); + ps_b_cur *= dt_size; + + // Loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + const char* a1_i = a1 + off_b0111 * PACKMR * dt_size; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, ps_b_cur, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k_b0111, + ( void* )alpha_cast, + ( void* )a1_i, + ( void* )b1, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; if ( ut == n_ut_for_me ) return; + } + + // Upon reaching the end of the column of microtiles, reset the ir + // loop index so that we're ready to start the next pass through the + // m dimension (i.e., the next jr loop iteration). + i = 0; + + b1 += ps_b_cur; + } + else if ( bli_is_strictly_above_diag_n( diagoffb_j, k, NR ) ) + { + // Loop over the m dimension (MR rows at a time). + for ( ; i < m_iter; ++i ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter_sl( i, m_iter ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, 1 ); + bli_auxinfo_set_next_b( b2, &aux ); + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )one, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + // Increment the microtile counter and check if the thread is done. + ut += 1; if ( ut == n_ut_for_me ) return; + } + + // Upon reaching the end of the column of microtiles, reset the ir + // loop index so that we're ready to start the next pass through the + // m dimension (i.e., the next jr loop iteration). + i = 0; + + b1 += cstep_b; + } + } +} + +//PASTEMAC(ch,fprintm)( stdout, "trmm_ru_ker_var2: a1", MR, k_b0111, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,fprintm)( stdout, "trmm_ru_ker_var2: b1", k_b0111, NR, b1_i, NR, 1, "%4.1f", "" ); + diff --git a/frame/3/trmm/bli_trmm_var.h b/frame/3/trmm/bli_trmm_var.h index f8c3d7ee20..0a605ba86a 100644 --- a/frame/3/trmm/bli_trmm_var.h +++ b/frame/3/trmm/bli_trmm_var.h @@ -43,54 +43,23 @@ \ void PASTEMAC0(opname) \ ( \ - const obj_t* a, \ - const obj_t* b, \ - const obj_t* c, \ - const cntx_t* cntx, \ - const cntl_t* cntl, \ - thrinfo_t* thread \ + const obj_t* a, \ + const obj_t* b, \ + const obj_t* c, \ + const cntx_t* cntx, \ + const cntl_t* cntl, \ + thrinfo_t* thread_par \ ); -//GENPROT( trmm_blk_var1 ) -//GENPROT( trmm_blk_var2 ) -//GENPROT( trmm_blk_var3 ) - GENPROT( trmm_xx_ker_var2 ) - GENPROT( trmm_ll_ker_var2 ) GENPROT( trmm_lu_ker_var2 ) GENPROT( trmm_rl_ker_var2 ) GENPROT( trmm_ru_ker_var2 ) - -// -// Prototype BLAS-like interfaces with void pointer operands. -// - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoff, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - thrinfo_t* thread \ - ); - -INSERT_GENTPROT_BASIC0( trmm_ll_ker_var2 ) -INSERT_GENTPROT_BASIC0( trmm_lu_ker_var2 ) -INSERT_GENTPROT_BASIC0( trmm_rl_ker_var2 ) -INSERT_GENTPROT_BASIC0( trmm_ru_ker_var2 ) +GENPROT( trmm_xx_ker_var2b ) +GENPROT( trmm_ll_ker_var2b ) +GENPROT( trmm_lu_ker_var2b ) +GENPROT( trmm_rl_ker_var2b ) +GENPROT( trmm_ru_ker_var2b ) diff --git a/frame/3/trmm/bli_trmm_xx_ker_var2.c b/frame/3/trmm/bli_trmm_xx_ker_var2.c index 60030bf4aa..918b8f973e 100644 --- a/frame/3/trmm/bli_trmm_xx_ker_var2.c +++ b/frame/3/trmm/bli_trmm_xx_ker_var2.c @@ -43,12 +43,12 @@ static l3_var_oft vars[2][2] = void bli_trmm_xx_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, - thrinfo_t* thread + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par ) { dim_t side; @@ -81,7 +81,7 @@ void bli_trmm_xx_ker_var2 c, cntx, cntl, - thread + thread_par ); } diff --git a/frame/3/trmm/bli_trmm_xx_ker_var2b.c b/frame/3/trmm/bli_trmm_xx_ker_var2b.c new file mode 100644 index 0000000000..57894165ce --- /dev/null +++ b/frame/3/trmm/bli_trmm_xx_ker_var2b.c @@ -0,0 +1,87 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +static l3_var_oft vars[2][2] = +{ + { bli_trmm_ll_ker_var2b, bli_trmm_lu_ker_var2b }, + { bli_trmm_rl_ker_var2b, bli_trmm_ru_ker_var2b } +}; + +void bli_trmm_xx_ker_var2b + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + dim_t side; + dim_t uplo; + + // Set two bools: one based on the implied side parameter (the structure + // of the root object) and one based on the uplo field of the triangular + // matrix's root object (whether that is matrix A or matrix B). + if ( bli_obj_root_is_triangular( a ) ) + { + side = 0; + if ( bli_obj_root_is_lower( a ) ) uplo = 0; + else uplo = 1; + } + else // if ( bli_obj_root_is_triangular( b ) ) + { + side = 1; + if ( bli_obj_root_is_lower( b ) ) uplo = 0; + else uplo = 1; + } + + // Index into the variant array to extract the correct function pointer. + l3_var_oft f = vars[side][uplo]; + + // Call the macrokernel. + f + ( + a, + b, + c, + cntx, + cntl, + thread_par + ); +} + diff --git a/frame/3/trmm/other/bli_trmm_rl_ker_var2.c.prev b/frame/3/trmm/other/bli_trmm_rl_ker_var2.c.prev new file mode 100644 index 0000000000..5aebe23c1c --- /dev/null +++ b/frame/3/trmm/other/bli_trmm_rl_ker_var2.c.prev @@ -0,0 +1,371 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_rl_ker_var2 + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt = bli_obj_exec_dt( c ); + const dim_t dt_size = bli_dt_size( dt ); + + doff_t diagoffb = bli_obj_diag_offset( b ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + const dim_t PACKMR = cs_a; + const dim_t PACKNR = rs_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); + + const void* one = bli_obj_buffer_for_const( dt, &BLIS_ONE ); + const char* a_cast = buf_a; + const char* b_cast = buf_b; + char* c_cast = buf_c; + const char* alpha_cast = buf_alpha; + const char* beta_cast = buf_beta; + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + // Safety trap: Certain indexing within this macro-kernel does not + // work as intended if both MR and NR are odd. + if ( ( bli_is_odd( PACKMR ) && bli_is_odd( NR ) ) || + ( bli_is_odd( PACKNR ) && bli_is_odd( MR ) ) ) bli_abort(); + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Safeguard: If the current panel of B is entirely above the diagonal, + // it is implicitly zero. So we do nothing. + if ( bli_is_strictly_above_diag_n( diagoffb, k, n ) ) return; + + // If there is a zero region above where the diagonal of B intersects + // the left edge of the panel, adjust the pointer to A and treat this + // case as if the diagonal offset were zero. Note that we don't need to + // adjust the pointer to B since packm would have simply skipped over + // the region that was not stored. + if ( diagoffb < 0 ) + { + k += diagoffb; + a_cast -= diagoffb * PACKMR * dt_size; + diagoffb = 0; + } + + // If there is a zero region to the right of where the diagonal + // of B intersects the bottom of the panel, shrink it to prevent + // "no-op" iterations from executing. + if ( diagoffb + k < n ) + { + n = diagoffb + k; + } + + // Compute number of primary and leftover components of the m and n + // dimensions. + dim_t n_iter = n / NR; + dim_t n_left = n % NR; + + dim_t m_iter = m / MR; + dim_t m_left = m % MR; + + if ( n_left ) ++n_iter; + if ( m_left ) ++m_iter; + + // Determine some increments used to step through A, B, and C. + inc_t rstep_a = ps_a * dt_size; + + inc_t cstep_b = ps_b * dt_size; + + inc_t rstep_c = rs_c * MR * dt_size; + inc_t cstep_c = cs_c * NR * dt_size; + + // Save the pack schemas of A and B to the auxinfo_t object. + auxinfo_t aux; + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + dim_t jr_nt = bli_thrinfo_n_way( thread ); + dim_t jr_tid = bli_thrinfo_work_id( thread ); + dim_t ir_nt = bli_thrinfo_n_way( caucus ); + dim_t ir_tid = bli_thrinfo_work_id( caucus ); + + dim_t jr_start, jr_end; + dim_t ir_start, ir_end; + dim_t jr_inc, ir_inc; + + // Note that we partition the 2nd loop into two regions: the rectangular + // part of B, and the triangular portion. + dim_t n_iter_rct; + dim_t n_iter_tri; + + if ( bli_is_strictly_below_diag_n( diagoffb, m, n ) ) + { + // If the entire panel of B does not intersect the diagonal, there is + // no triangular region, and therefore we can skip the second set of + // loops. + n_iter_rct = n_iter; + n_iter_tri = 0; + } + else + { + // If the panel of B does intersect the diagonal, compute the number of + // iterations in the rectangular region by dividing NR into the diagonal + // offset. (There should never be any remainder in this division.) The + // number of iterations in the triangular (or trapezoidal) region is + // computed as the remaining number of iterations in the n dimension. + n_iter_rct = diagoffb / NR; + n_iter_tri = n_iter - n_iter_rct; + } + + // Determine the thread range and increment for the 2nd and 1st loops for + // the initial rectangular region of B (if it exists). + // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // slab or round-robin partitioning was requested at configure-time. + // NOTE: Parallelism in the 1st loop is disabled for now. + bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + + // Loop over the n dimension (NR columns at a time). + for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) + { + const char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + { + // Loop over the m dimension (MR rows at a time). + for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) + { + const char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, ir_inc ); + if ( bli_is_last_iter( i, m_iter, ir_tid, ir_nt ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); + if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) + b2 = b_cast; + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )one, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + } + } + } + + // If there is no triangular region, then we're done. + if ( n_iter_tri == 0 ) return; + + // Use round-robin assignment of micropanels to threads in the 2nd and + // 1st loops for the remaining triangular region of B (if it exists). + // NOTE: We don't need to call bli_thread_range_jrir_rr() here since we + // employ a hack that calls for each thread to execute every iteration + // of the jr and ir loops but skip all but the pointer increment for + // iterations that are not assigned to it. + + // Advance the starting b1 and c1 pointers to the positions corresponding + // to the start of the triangular region of B. + jr_start = n_iter_rct; + const char* b1 = b_cast + jr_start * cstep_b; + char* c1 = c_cast + jr_start * cstep_c; + + // Loop over the n dimension (NR columns at a time). + for ( dim_t j = jr_start; j < n_iter; ++j ) + { + doff_t diagoffb_j = diagoffb - ( doff_t )j*NR; + + // Determine the offset to the beginning of the panel that + // was packed so we can index into the corresponding location + // in A. Then compute the length of that panel. + dim_t off_b1121 = bli_max( -diagoffb_j, 0 ); + dim_t k_b1121 = k - off_b1121; + + const char* a1 = a_cast; + char* c11 = c1; + + dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + // If the current panel of B intersects the diagonal, scale C + // by beta. If it is strictly below the diagonal, scale by one. + // This allows the current macro-kernel to work for both trmm + // and trmm3. + { + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_b_cur = k_b1121 * PACKNR; + ps_b_cur += ( bli_is_odd( ps_b_cur ) ? 1 : 0 ); + ps_b_cur *= dt_size; + + if ( bli_trmm_my_iter_rr( j, thread ) ) { + + // Loop over the m dimension (MR rows at a time). + for ( dim_t i = 0; i < m_iter; ++i ) + { + if ( bli_trmm_my_iter_rr( i, caucus ) ) { + + dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + + const char* a1_i = a1 + off_b1121 * PACKMR * dt_size; + + // Compute the addresses of the next panels of A and B. + const char* a2 = a1; + if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) + { + a2 = a_cast; + b2 = b1; + if ( bli_is_last_iter_rr( j, n_iter, jr_tid, jr_nt ) ) + b2 = b_cast; + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k_b1121, + ( void* )alpha_cast, + ( void* )a1_i, + ( void* )b1, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + } + + a1 += rstep_a; + c11 += rstep_c; + } + } + + b1 += ps_b_cur; + } + + c1 += cstep_c; + } +} + +//PASTEMAC(ch,fprintm)( stdout, "trmm_rl_ker_var2: a1", MR, k_b1121, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,fprintm)( stdout, "trmm_rl_ker_var2: b1", k_b1121, NR, b1_i, NR, 1, "%4.1f", "" ); + diff --git a/frame/3/trmm/other/bli_trmm_rl_ker_var2.c.unified b/frame/3/trmm/other/bli_trmm_rl_ker_var2.c.unified new file mode 100644 index 0000000000..7d2aabaa4b --- /dev/null +++ b/frame/3/trmm/other/bli_trmm_rl_ker_var2.c.unified @@ -0,0 +1,324 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_rl_ker_var2 + ( + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par + ) +{ + const num_t dt = bli_obj_exec_dt( c ); + const dim_t dt_size = bli_dt_size( dt ); + + doff_t diagoffb = bli_obj_diag_offset( b ); + + const pack_t schema_a = bli_obj_pack_schema( a ); + const pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + const void* buf_a = bli_obj_buffer_at_off( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + const dim_t pd_a = bli_obj_panel_dim( a ); + const inc_t ps_a = bli_obj_panel_stride( a ); + + const void* buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const dim_t pd_b = bli_obj_panel_dim( b ); + const inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + // Detach and multiply the scalars attached to A and B. + obj_t scalar_a, scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + const void* buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + const void* buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + const dim_t PACKMR = cs_a; + const dim_t PACKNR = rs_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); + + const void* one = bli_obj_buffer_for_const( dt, &BLIS_ONE ); + const char* a_cast = buf_a; + const char* b_cast = buf_b; + char* c_cast = buf_c; + const char* alpha_cast = buf_alpha; + const char* beta_cast = buf_beta; + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + // Safety trap: Certain indexing within this macro-kernel does not + // work as intended if both MR and NR are odd. + if ( ( bli_is_odd( PACKMR ) && bli_is_odd( NR ) ) || + ( bli_is_odd( PACKNR ) && bli_is_odd( MR ) ) ) bli_abort(); + + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; + + // Safeguard: If the current panel of B is entirely above the diagonal, + // it is implicitly zero. So we do nothing. + if ( bli_is_strictly_above_diag_n( diagoffb, k, n ) ) return; + + // If there is a zero region above where the diagonal of B intersects + // the left edge of the panel, adjust the pointer to A and treat this + // case as if the diagonal offset were zero. Note that we don't need to + // adjust the pointer to B since packm would have simply skipped over + // the region that was not stored. + if ( diagoffb < 0 ) + { + k += diagoffb; + a_cast -= diagoffb * PACKMR * dt_size; + diagoffb = 0; + } + + // If there is a zero region to the right of where the diagonal + // of B intersects the bottom of the panel, shrink it to prevent + // "no-op" iterations from executing. + if ( diagoffb + k < n ) + { + n = diagoffb + k; + } + + // Compute number of primary and leftover components of the m and n + // dimensions. + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; + + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; + + // Determine some increments used to step through A, B, and C. + const inc_t rstep_a = ps_a * dt_size; + + const inc_t cstep_b = ps_b * dt_size; + + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel while the 'caucus' points to the thrinfo_t + // node for the 1st loop (ir). + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + // Query the number of threads and thread ids for each loop. + //const dim_t jr_nt = bli_thrinfo_n_way( thread ); + //const dim_t jr_tid = bli_thrinfo_work_id( thread ); + //const dim_t ir_nt = bli_thrinfo_n_way( caucus ); + //const dim_t ir_tid = bli_thrinfo_work_id( caucus ); + + dim_t jr_start, jr_end, jr_inc; + dim_t ir_start, ir_end, ir_inc; + + // Determine the thread range and increment for the 2nd and 1st loops. + // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // slab or round-robin partitioning was requested at configure-time. + // NOTE: Parallelism in the 1st loop is disabled for now. + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + + const char* b1 = b_cast; + // char* c1 = c_cast; + + // Loop over the n dimension (NR columns at a time). + for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) + { + const char* a1 = a_cast; + char* c1 = c_cast + j * cstep_c; + char* c11 = c1; + + const doff_t diagoffb_j = diagoffb - ( doff_t )j*NR; + + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); + + // Determine the offset to the beginning of the panel that + // was packed so we can index into the corresponding location + // in A. Then compute the length of that panel. + const dim_t off_b1121 = bli_max( -diagoffb_j, 0 ); + const dim_t k_b1121 = k - off_b1121; + + // Initialize our next panel of B to be the current panel of B. + const char* b2 = b1; + + // If the current panel of B intersects the diagonal, scale C + // by beta. If it is strictly below the diagonal, scale by one. + // This allows the current macro-kernel to work for both trmm + // and trmm3. + if ( bli_intersects_diag_n( diagoffb_j, k, NR ) ) + { + // Compute the panel stride for the current diagonal- + // intersecting micro-panel. + inc_t ps_b_cur = k_b1121 * PACKNR; + ps_b_cur += ( bli_is_odd( ps_b_cur ) ? 1 : 0 ); + ps_b_cur *= dt_size; + + // Loop over the m dimension (MR rows at a time). + for ( dim_t i = 0; i < m_iter; ++i ) + //for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) + { + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + const char* a1_i = a1 + off_b1121 * PACKMR * dt_size; + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter( i, m_iter, 0, 1 ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); + //if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) + // b2 = b_cast; + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k_b1121, + ( void* )alpha_cast, + ( void* )a1_i, + ( void* )b1, + ( void* )beta_cast, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + a1 += rstep_a; + c11 += rstep_c; + } + + b1 += ps_b_cur; + } + else if ( bli_is_strictly_below_diag_n( diagoffb_j, k, NR ) ) + { + // Loop over the m dimension (MR rows at a time). + for ( dim_t i = 0; i < m_iter; ++i ) + { + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); + + // Compute the addresses of the next panels of A and B. + const char* a2 = bli_trmm_get_next_a_upanel( a1, rstep_a, 1 ); + if ( bli_is_last_iter( i, m_iter, 0, 1 ) ) + { + a2 = a_cast; + b2 = bli_trmm_get_next_b_upanel( b1, cstep_b, jr_inc ); + //if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) + // b2 = b_cast; + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + ( void* )alpha_cast, + ( void* )a1, + ( void* )b1, + ( void* )one, + c11, rs_c, cs_c, + &aux, + ( cntx_t* )cntx + ); + + a1 += rstep_a; + c11 += rstep_c; + } + + b1 += cstep_b; + } + + //c1 += cstep_c; + } +} + +//PASTEMAC(ch,fprintm)( stdout, "trmm_rl_ker_var2: a1", MR, k_b1121, a1, 1, MR, "%4.1f", "" ); +//PASTEMAC(ch,fprintm)( stdout, "trmm_rl_ker_var2: b1", k_b1121, NR, b1_i, NR, 1, "%4.1f", "" ); + diff --git a/frame/3/trmm/other/bli_trmm_ru_ker_var2.c b/frame/3/trmm/other/bli_trmm_ru_ker_var2.c index 275d6ca470..45af769104 100644 --- a/frame/3/trmm/other/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/other/bli_trmm_ru_ker_var2.c @@ -356,7 +356,7 @@ void PASTEMAC(ch,varname) \ b2 = b1; \ \ /* If the current panel of B intersects the diagonal, scale C - by beta. If it is strictly below the diagonal, scale by one. + by beta. If it is strictly above the diagonal, scale by one. This allows the current macro-kernel to work for both trmm and trmm3. */ \ if ( bli_intersects_diag_n( diagoffb_j, k, NR ) ) \ diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index e2128f1009..786e4f343f 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -37,11 +37,11 @@ void bli_trsm_ll_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { @@ -158,47 +158,44 @@ void bli_trsm_ll_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_size; - inc_t cstep_c = cs_c * NR * dt_size; + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; - // Save the pack schemas of A and B to the auxinfo_t object. auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); // We don't bother querying the thrinfo_t node for the 1st loop because // we can't parallelize that loop in trsm due to the inter-iteration // dependencies that exist. + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); // Query the number of threads and thread ids for each loop. - thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); - dim_t jr_start, jr_end; - dim_t jr_inc; + dim_t jr_start, jr_end, jr_inc; // Determine the thread range and increment for the 2nd loop. - // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // NOTE: The definition of bli_thread_range_slrr() will depend on whether // slab or round-robin partitioning was requested at configure-time. // NOTE: Parallelism in the 1st loop is unattainable due to the // inter-iteration dependencies present in trsm. - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) @@ -206,7 +203,8 @@ void bli_trsm_ll_ker_var2 const char* b1 = b_cast + j * cstep_b; char* c1 = c_cast + j * cstep_c; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -217,9 +215,10 @@ void bli_trsm_ll_ker_var2 // Loop over the m dimension (MR rows at a time). for ( dim_t i = 0; i < m_iter; ++i ) { - doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; + const doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; - dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + const dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) + ? MR : m_left ); // If the current panel of A intersects the diagonal, use a // special micro-kernel that performs a fused gemm and trsm. @@ -230,10 +229,10 @@ void bli_trsm_ll_ker_var2 if ( bli_intersects_diag_n( diagoffa_i, MR, k ) ) { // Compute various offsets into and lengths of parts of A. - dim_t off_a10 = 0; - dim_t k_a1011 = diagoffa_i + MR; - dim_t k_a10 = k_a1011 - MR; - dim_t off_a11 = k_a10; + const dim_t off_a10 = 0; + const dim_t k_a1011 = diagoffa_i + MR; + const dim_t k_a10 = k_a1011 - MR; + const dim_t off_a11 = k_a10; // Compute the panel stride for the current diagonal- // intersecting micro-panel. @@ -258,7 +257,7 @@ void bli_trsm_ll_ker_var2 { a2 = a_cast; b2 = b1; - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) + if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) b2 = b_cast; } @@ -292,7 +291,7 @@ void bli_trsm_ll_ker_var2 { a2 = a_cast; b2 = b1; - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) + if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) b2 = b_cast; } diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 314ee30706..ebf44905b4 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -37,11 +37,11 @@ void bli_trsm_lu_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { @@ -169,47 +169,44 @@ void bli_trsm_lu_ker_var2 // Compute number of primary and leftover components of the m and n // dimensions. - dim_t n_iter = n / NR; - dim_t n_left = n % NR; + const dim_t n_iter = n / NR + ( n % NR ? 1 : 0 ); + const dim_t n_left = n % NR; - dim_t m_iter = m / MR; - dim_t m_left = m % MR; - - if ( n_left ) ++n_iter; - if ( m_left ) ++m_iter; + const dim_t m_iter = m / MR + ( m % MR ? 1 : 0 ); + const dim_t m_left = m % MR; // Determine some increments used to step through A, B, and C. - inc_t rstep_a = ps_a * dt_size; + const inc_t rstep_a = ps_a * dt_size; - inc_t cstep_b = ps_b * dt_size; + const inc_t cstep_b = ps_b * dt_size; - inc_t rstep_c = rs_c * MR * dt_size; - inc_t cstep_c = cs_c * NR * dt_size; + const inc_t rstep_c = rs_c * MR * dt_size; + const inc_t cstep_c = cs_c * NR * dt_size; - // Save the pack schemas of A and B to the auxinfo_t object. auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); // We don't bother querying the thrinfo_t node for the 1st loop because // we can't parallelize that loop in trsm due to the inter-iteration // dependencies that exist. + thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); //thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); // Query the number of threads and thread ids for each loop. - thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); - dim_t jr_nt = bli_thrinfo_n_way( thread ); - dim_t jr_tid = bli_thrinfo_work_id( thread ); + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t jr_tid = bli_thrinfo_work_id( thread ); - dim_t jr_start, jr_end; - dim_t jr_inc; + dim_t jr_start, jr_end, jr_inc; // Determine the thread range and increment for the 2nd loop. - // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // NOTE: The definition of bli_thread_range_slrr() will depend on whether // slab or round-robin partitioning was requested at configure-time. // NOTE: Parallelism in the 1st loop is unattainable due to the // inter-iteration dependencies present in trsm. - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) @@ -217,7 +214,8 @@ void bli_trsm_lu_ker_var2 const char* b1 = b_cast + j * cstep_b; char* c1 = c_cast + j * cstep_c; - dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + const dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) + ? NR : n_left ); // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; @@ -228,10 +226,11 @@ void bli_trsm_lu_ker_var2 // Loop over the m dimension (MR rows at a time). for ( dim_t ib = 0; ib < m_iter; ++ib ) { - dim_t i = m_iter - 1 - ib; - doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; + const dim_t i = m_iter - 1 - ib; + const doff_t diagoffa_i = diagoffa + ( doff_t )i*MR; - dim_t m_cur = ( bli_is_not_edge_b( ib, m_iter, m_left ) ? MR : m_left ); + const dim_t m_cur = ( bli_is_not_edge_b( ib, m_iter, m_left ) + ? MR : m_left ); // If the current panel of A intersects the diagonal, use a // special micro-kernel that performs a fused gemm and trsm. @@ -242,11 +241,11 @@ void bli_trsm_lu_ker_var2 if ( bli_intersects_diag_n( diagoffa_i, MR, k ) ) { // Compute various offsets into and lengths of parts of A. - dim_t off_a11 = diagoffa_i; - dim_t k_a1112 = k - off_a11;; - dim_t k_a11 = MR; - dim_t k_a12 = k_a1112 - MR; - dim_t off_a12 = off_a11 + k_a11; + const dim_t off_a11 = diagoffa_i; + const dim_t k_a1112 = k - off_a11;; + const dim_t k_a11 = MR; + const dim_t k_a12 = k_a1112 - MR; + const dim_t off_a12 = off_a11 + k_a11; // Compute the panel stride for the current diagonal- // intersecting micro-panel. @@ -271,7 +270,7 @@ void bli_trsm_lu_ker_var2 { a2 = a_cast; b2 = b1; - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) + if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) b2 = b_cast; } @@ -305,7 +304,7 @@ void bli_trsm_lu_ker_var2 { a2 = a_cast; b2 = b1; - if ( bli_is_last_iter( j, n_iter, jr_tid, jr_nt ) ) + if ( bli_is_last_iter_slrr( j, n_iter, jr_tid, jr_nt ) ) b2 = b_cast; } diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 42e72840ef..073fe3ec07 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -37,11 +37,11 @@ void bli_trsm_rl_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { @@ -131,23 +131,23 @@ void bli_trsm_rl_ker_var2 the right-hand side parameter case). */ - /* Safety trap: Certain indexing within this macro-kernel does not - work as intended if both MR and NR are odd. */ + // Safety trap: Certain indexing within this macro-kernel does not + // work as intended if both MR and NR are odd. if ( ( bli_is_odd( PACKMR ) && bli_is_odd( NR ) ) || ( bli_is_odd( PACKNR ) && bli_is_odd( MR ) ) ) bli_abort(); - /* If any dimension is zero, return immediately. */ + // If any dimension is zero, return immediately. if ( bli_zero_dim3( m, n, k ) ) return; - /* Safeguard: If the current panel of B is entirely above its diagonal, - it is implicitly zero. So we do nothing. */ + // Safeguard: If the current panel of B is entirely above its diagonal, + // it is implicitly zero. So we do nothing. if ( bli_is_strictly_above_diag_n( diagoffb, k, n ) ) return; - /* If there is a zero region above where the diagonal of B intersects - the left edge of the panel, adjust the pointer to A and treat this - case as if the diagonal offset were zero. Note that we don't need to - adjust the pointer to B since packm would have simply skipped over - the region that was not stored. */ + // If there is a zero region above where the diagonal of B intersects + // the left edge of the panel, adjust the pointer to A and treat this + // case as if the diagonal offset were zero. Note that we don't need to + // adjust the pointer to B since packm would have simply skipped over + // the region that was not stored. if ( diagoffb < 0 ) { k += diagoffb; @@ -155,40 +155,40 @@ void bli_trsm_rl_ker_var2 diagoffb = 0; } - /* If there is a zero region to the right of where the diagonal - of B intersects the bottom of the panel, shrink it so that - we can index to the correct place in C (corresponding to the - part of the panel of B that was packed). - NOTE: This is NOT being done to skip over "no-op" iterations, - as with the trsm_lu macro-kernel. This MUST be done for correct - execution because we use n (via n_iter) to compute diagonal and - index offsets for backwards movement through B. */ + // If there is a zero region to the right of where the diagonal + // of B intersects the bottom of the panel, shrink it so that + // we can index to the correct place in C (corresponding to the + // part of the panel of B that was packed). + // NOTE: This is NOT being done to skip over "no-op" iterations, + // as with the trsm_lu macro-kernel. This MUST be done for correct + // execution because we use n (via n_iter) to compute diagonal and + // index offsets for backwards movement through B. if ( diagoffb + k < n ) { n = diagoffb + k; } - /* Check the k dimension, which needs to be a multiple of NR. If k - isn't a multiple of NR, we adjust it higher to satisfy the micro- - kernel, which is expecting to perform an NR x NR triangular solve. - This adjustment of k is consistent with what happened when B was - packed: all of its bottom/right edges were zero-padded, and - furthermore, the panel that stores the bottom-right corner of the - matrix has its diagonal extended into the zero-padded region (as - identity). This allows the trsm of that bottom-right panel to - proceed without producing any infs or NaNs that would infect the - "good" values of the corresponding block of A. */ + // Check the k dimension, which needs to be a multiple of NR. If k + // isn't a multiple of NR, we adjust it higher to satisfy the micro- + // kernel, which is expecting to perform an NR x NR triangular solve. + // This adjustment of k is consistent with what happened when B was + // packed: all of its bottom/right edges were zero-padded, and + // furthermore, the panel that stores the bottom-right corner of the + // matrix has its diagonal extended into the zero-padded region (as + // identity). This allows the trsm of that bottom-right panel to + // proceed without producing any infs or NaNs that would infect the + // "good" values of the corresponding block of A. if ( k % NR != 0 ) k += NR - ( k % NR ); - /* NOTE: We don't need to check that n is a multiple of PACKNR since we - know that the underlying buffer was already allocated to have an n - dimension that is a multiple of PACKNR, with the region between the - last column and the next multiple of NR zero-padded accordingly. */ + // NOTE: We don't need to check that n is a multiple of PACKNR since we + // know that the underlying buffer was already allocated to have an n + // dimension that is a multiple of PACKNR, with the region between the + // last column and the next multiple of NR zero-padded accordingly. thrinfo_t* thread = bli_thrinfo_sub_node( thread_par ); - /* Compute number of primary and leftover components of the m and n - dimensions. */ + // Compute number of primary and leftover components of the m and n + // dimensions. dim_t n_iter = n / NR; dim_t n_left = n % NR; @@ -198,7 +198,7 @@ void bli_trsm_rl_ker_var2 if ( n_left ) ++n_iter; if ( m_left ) ++m_iter; - /* Determine some increments used to step through A, B, and C. */ + // Determine some increments used to step through A, B, and C. inc_t rstep_a = ps_a * dt_size; inc_t cstep_b = ps_b * dt_size; @@ -206,17 +206,18 @@ void bli_trsm_rl_ker_var2 inc_t rstep_c = rs_c * MR * dt_size; inc_t cstep_c = cs_c * NR * dt_size; - /* Save the pack schemas of A and B to the auxinfo_t object. - NOTE: We swap the values for A and B since the triangular - "A" matrix is actually contained within B. */ auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + // NOTE: We swap the values for A and B since the triangular + // "A" matrix is actually contained within B. bli_auxinfo_set_schema_a( schema_b, &aux ); bli_auxinfo_set_schema_b( schema_a, &aux ); const char* b1 = b_cast; char* c1 = c_cast; - /* Loop over the n dimension (NR columns at a time). */ + // Loop over the n dimension (NR columns at a time). for ( dim_t jb = 0; jb < n_iter; ++jb ) { dim_t j = n_iter - 1 - jb; @@ -227,50 +228,50 @@ void bli_trsm_rl_ker_var2 const char* a1 = a_cast; char* c11 = c1 + (n_iter-1)*cstep_c; - /* Initialize our next panel of B to be the current panel of B. */ + // Initialize our next panel of B to be the current panel of B. const char* b2 = b1; - /* If the current panel of B intersects the diagonal, use a - special micro-kernel that performs a fused gemm and trsm. - If the current panel of B resides below the diagonal, use a - a regular gemm micro-kernel. Otherwise, if it is above the - diagonal, it was not packed (because it is implicitly zero) - and so we do nothing. */ + // If the current panel of B intersects the diagonal, use a + // special micro-kernel that performs a fused gemm and trsm. + // If the current panel of B resides below the diagonal, use a + // a regular gemm micro-kernel. Otherwise, if it is above the + // diagonal, it was not packed (because it is implicitly zero) + // and so we do nothing. if ( bli_intersects_diag_n( diagoffb_j, k, NR ) ) { - /* Determine the offset to and length of the panel that was packed - so we can index into the corresponding location in A. */ + // Determine the offset to and length of the panel that was packed + // so we can index into the corresponding location in A. dim_t off_b11 = bli_max( -diagoffb_j, 0 ); dim_t k_b1121 = k - off_b11; dim_t k_b11 = NR; dim_t k_b21 = k_b1121 - NR; dim_t off_b21 = off_b11 + k_b11; - /* Compute the addresses of the triangular block B11 and the - panel B21. */ + // Compute the addresses of the triangular block B11 and the + // panel B21. const char* b11 = b1; const char* b21 = b1 + k_b11 * PACKNR * dt_size; - /*b21 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b11 * PACKNR, 1 );*/ + //b21 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b11 * PACKNR, 1 ); - /* Compute the panel stride for the current micro-panel. */ + // Compute the panel stride for the current micro-panel. inc_t ps_b_cur = k_b1121 * PACKNR; ps_b_cur += ( bli_is_odd( ps_b_cur ) ? 1 : 0 ); ps_b_cur *= dt_size; - /* Loop over the m dimension (MR rows at a time). */ + // Loop over the m dimension (MR rows at a time). for ( dim_t i = 0; i < m_iter; ++i ) { if ( bli_trsm_my_iter_rr( i, thread ) ){ dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); - /* Compute the addresses of the A11 block and A12 panel. */ + // Compute the addresses of the A11 block and A12 panel. const char* a11 = a1 + off_b11 * PACKMR * dt_size; const char* a12 = a1 + off_b21 * PACKMR * dt_size; - /* Compute the addresses of the next panels of A and B. */ + // Compute the addresses of the next panels of A and B. const char* a2 = a1; - /*if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) */ + //if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) if ( i + bli_thrinfo_num_threads(thread) >= m_iter ) { a2 = a_cast; @@ -279,9 +280,9 @@ void bli_trsm_rl_ker_var2 b2 = b_cast; } - /* Save addresses of next panels of A and B to the auxinfo_t - object. NOTE: We swap the values for A and B since the - triangular "A" matrix is actually contained within B. */ + // Save addresses of next panels of A and B to the auxinfo_t + // object. NOTE: We swap the values for A and B since the + // triangular "A" matrix is actually contained within B. bli_auxinfo_set_next_a( b2, &aux ); bli_auxinfo_set_next_b( a2, &aux ); @@ -310,16 +311,16 @@ void bli_trsm_rl_ker_var2 } else if ( bli_is_strictly_below_diag_n( diagoffb_j, k, NR ) ) { - /* Loop over the m dimension (MR rows at a time). */ + // Loop over the m dimension (MR rows at a time). for ( dim_t i = 0; i < m_iter; ++i ) { if ( bli_trsm_my_iter_rr( i, thread ) ){ dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); - /* Compute the addresses of the next panels of A and B. */ + // Compute the addresses of the next panels of A and B. const char* a2 = a1; - /*if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) */ + //if ( bli_is_last_iter_rr( i, m_iter, 0, 1 ) ) if ( i + bli_thrinfo_num_threads(thread) >= m_iter ) { a2 = a_cast; @@ -328,13 +329,13 @@ void bli_trsm_rl_ker_var2 b2 = b_cast; } - /* Save addresses of next panels of A and B to the auxinfo_t - object. NOTE: We swap the values for A and B since the - triangular "A" matrix is actually contained within B. */ + // Save addresses of next panels of A and B to the auxinfo_t + // object. NOTE: We swap the values for A and B since the + // triangular "A" matrix is actually contained within B. bli_auxinfo_set_next_a( b2, &aux ); bli_auxinfo_set_next_b( a2, &aux ); - /* Invoke the gemm micro-kernel. */ + // Invoke the gemm micro-kernel. gemm_ukr ( m_cur, diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 6cc9a8bbb2..a05e944941 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -37,11 +37,11 @@ void bli_trsm_ru_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, thrinfo_t* thread_par ) { @@ -244,7 +244,7 @@ void bli_trsm_ru_ker_var2 // block B11. const char* b01 = b1; const char* b11 = b1 + k_b01 * PACKNR * dt_size; - //b11 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b01 * PACKNR, 1 );*/ + //b11 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b01 * PACKNR, 1 ); // Compute the panel stride for the current micro-panel. inc_t ps_b_cur = k_b0111 * PACKNR; diff --git a/frame/3/trsm/bli_trsm_var.h b/frame/3/trsm/bli_trsm_var.h index a498e687e3..4d7e72b436 100644 --- a/frame/3/trsm/bli_trsm_var.h +++ b/frame/3/trsm/bli_trsm_var.h @@ -48,7 +48,7 @@ void PASTEMAC0(opname) \ const obj_t* c, \ const cntx_t* cntx, \ const cntl_t* cntl, \ - thrinfo_t* thread \ + thrinfo_t* thread_par \ ); GENPROT( trsm_blk_var1 ) diff --git a/frame/3/trsm/bli_trsm_xx_ker_var2.c b/frame/3/trsm/bli_trsm_xx_ker_var2.c index 39c5372f3e..dfeefcd9d9 100644 --- a/frame/3/trsm/bli_trsm_xx_ker_var2.c +++ b/frame/3/trsm/bli_trsm_xx_ker_var2.c @@ -43,12 +43,12 @@ static l3_var_oft vars[2][2] = void bli_trsm_xx_ker_var2 ( - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntx_t* cntx, - const cntl_t* cntl, - thrinfo_t* thread + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntx_t* cntx, + const cntl_t* cntl, + thrinfo_t* thread_par ) { dim_t side; @@ -81,7 +81,7 @@ void bli_trsm_xx_ker_var2 c, cntx, cntl, - thread + thread_par ); } diff --git a/frame/base/bli_info.c b/frame/base/bli_info.c index 1f00537d59..3fc76b978e 100644 --- a/frame/base/bli_info.c +++ b/frame/base/bli_info.c @@ -156,7 +156,7 @@ gint_t bli_info_get_enable_hpx_as_default( void ) return 0; #endif } -gint_t bli_info_get_thread_part_jrir_slab( void ) +gint_t bli_info_get_thread_jrir_slab( void ) { #ifdef BLIS_ENABLE_JRIR_SLAB return 1; @@ -164,7 +164,7 @@ gint_t bli_info_get_thread_part_jrir_slab( void ) return 0; #endif } -gint_t bli_info_get_thread_part_jrir_rr( void ) +gint_t bli_info_get_thread_jrir_rr( void ) { #ifdef BLIS_ENABLE_JRIR_RR return 1; @@ -172,6 +172,14 @@ gint_t bli_info_get_thread_part_jrir_rr( void ) return 0; #endif } +gint_t bli_info_get_thread_jrir_tlb( void ) +{ +#ifdef BLIS_ENABLE_JRIR_TLB + return 1; +#else + return 0; +#endif +} gint_t bli_info_get_enable_memkind( void ) { #ifdef BLIS_ENABLE_MEMKIND diff --git a/frame/base/bli_info.h b/frame/base/bli_info.h index 08a99daea9..300b3f5843 100644 --- a/frame/base/bli_info.h +++ b/frame/base/bli_info.h @@ -74,8 +74,9 @@ BLIS_EXPORT_BLIS gint_t bli_info_get_enable_hpx( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_enable_openmp_as_default( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_enable_pthreads_as_default( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_enable_hpx_as_default( void ); -BLIS_EXPORT_BLIS gint_t bli_info_get_thread_part_jrir_slab( void ); -BLIS_EXPORT_BLIS gint_t bli_info_get_thread_part_jrir_rr( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_thread_jrir_slab( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_thread_jrir_rr( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_thread_jrir_tlb( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_enable_memkind( void ); BLIS_EXPORT_BLIS gint_t bli_info_get_enable_sandbox( void ); diff --git a/frame/base/bli_prune.c b/frame/base/bli_prune.c index ebe5c23653..31c3d86d22 100644 --- a/frame/base/bli_prune.c +++ b/frame/base/bli_prune.c @@ -38,9 +38,28 @@ void bli_prune_unref_mparts( obj_t* p, mdim_t mdim_p, obj_t* s, mdim_t mdim_s ) { - // If the primary object is general, it has no structure, and + // NOTE: This function is not safe to use on packed objects because it does + // not currently take into account the atomicity of the packed micropanel + // widths (i.e., the register blocksize). That is, this function will prune + // greedily, without regard to whether doing so would prune off part of a + // micropanel *which has already been packed* and "assigned" to a thread for + // inclusion in the computation. In order to be safe for use use on packed + // matrices, this function would need to prune only up to the nearest + // micropanel edge (and to the corresponding location within the secondary + // matrix), which may not coincide exactly with the diagonal offset. + if ( bli_obj_is_packed( p ) || bli_obj_is_packed( s ) ) bli_abort(); + + // If the primary object is general AND dense, it has no structure, and // therefore, no unreferenced parts. - if ( bli_obj_is_general( p ) ) return; + // NOTE: There is at least one situation where the matrix is general but + // its uplo_t value is lower or upper: gemmt. This operation benefits from + // pruning unreferenced regions the same way herk/her2k/syrk/syr2k would. + // Because of gemmt, and any future similar operations, we limit early + // returns to situations where the primary object has a dense uplo_t value + // IN ADDITION TO general structure (rather than only checking for general + // structure). + if ( bli_obj_is_general( p ) && + bli_obj_is_dense( p ) ) return; // If the primary object is BLIS_ZEROS, set the dimensions so that the // matrix is empty. This is not strictly needed but rather a minor @@ -116,21 +135,13 @@ void bli_prune_unref_mparts( obj_t* p, mdim_t mdim_p, if ( bli_is_m_dim( mdim_p ) ) q = m; else /* if ( bli_is_n_dim( mdim_p ) ) */ q = n; - // Update the affected objects in case anything changed. Notice that - // it is okay to update the dimension and diagonal offset fields of - // packed primary objects, as long as we do so in tandem with the - // secondary object to maintain conformality. This just means that - // the "ignore-able" zero region is skipped over here, rather than - // within the macro-kernel. + // Update the affected objects' diagonal offset, dimensions, and row + // and column offsets, in case anything changed. bli_obj_set_diag_offset( diagoff_p, p ); bli_obj_set_dim( mdim_p, q, p ); bli_obj_set_dim( mdim_s, q, s ); - - // Only update the affected offset fields if the object in question - // is NOT a packed object. Otherwise, bli_obj_buffer_at_off() will - // compute the wrong address within the macro-kernel object wrapper. - if ( !bli_obj_is_packed( p ) ) { bli_obj_inc_off( mdim_p, off_inc, p ); } - if ( !bli_obj_is_packed( s ) ) { bli_obj_inc_off( mdim_s, off_inc, s ); } + bli_obj_inc_off( mdim_p, off_inc, p ); + bli_obj_inc_off( mdim_s, off_inc, s ); } } diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 786998f23d..64124c6823 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -143,12 +143,44 @@ void bli_rntm_set_ways_for_op // kind of information is already stored in the rntm_t object. bli_rntm_factorize( m, n, k, rntm ); -#if 0 -printf( "bli_rntm_set_ways_for_op()\n" ); -bli_rntm_print( rntm ); -#endif + #if 0 + printf( "bli_rntm_set_ways_for_op()\n" ); + bli_rntm_print( rntm ); + #endif // Now modify the number of ways, if necessary, based on the operation. + + // Consider gemm (hemm, symm), gemmt (herk, her2k, syrk, syr2k), and + // trmm (trmm, trmm3). + if ( +#ifdef BLIS_ENABLE_JRIR_TLB + l3_op == BLIS_GEMM || + l3_op == BLIS_GEMMT || + l3_op == BLIS_TRMM || +#endif + FALSE + ) + { + dim_t jc = bli_rntm_jc_ways( rntm ); + dim_t pc = bli_rntm_pc_ways( rntm ); + dim_t ic = bli_rntm_ic_ways( rntm ); + dim_t jr = bli_rntm_jr_ways( rntm ); + dim_t ir = bli_rntm_ir_ways( rntm ); + + // If TLB is enabled for gemm or gemmt, redirect any ir loop parallelism + // into the jr loop. + bli_rntm_set_ways_only + ( + jc, + pc, + ic, + jr * ir, + 1, + rntm + ); + } + + // Consider trmm, trmm3, trsm. if ( l3_op == BLIS_TRMM || l3_op == BLIS_TRSM ) { diff --git a/frame/include/bli_config_macro_defs.h b/frame/include/bli_config_macro_defs.h index 8719b0260b..6bd9b5c4c9 100644 --- a/frame/include/bli_config_macro_defs.h +++ b/frame/include/bli_config_macro_defs.h @@ -36,6 +36,16 @@ #ifndef BLIS_CONFIG_MACRO_DEFS_H #define BLIS_CONFIG_MACRO_DEFS_H +// NOTE: This file should ONLY contain processing of macros that are set by +// configure and output into bli_config.h. Any other macro processing -- +// especially such as for those macros that are expected to be optionally +// set within a configuration's bli_family_.h header -- MUST be placed +// in bli_kernel_macro_defs.h instead. The reason: bli_arch_config.h (which +// #includes the configuration's bli_family_.h header) is #included +// much later in blis.h than this file (bli_config_macro_defs.h), and so any +// macros set in bli_family_.h would have no effect on the processing +// that happens below. + // -- INTEGER PROPERTIES ------------------------------------------------------- diff --git a/frame/include/bli_kernel_macro_defs.h b/frame/include/bli_kernel_macro_defs.h index d273c353ab..8c0f1cb143 100644 --- a/frame/include/bli_kernel_macro_defs.h +++ b/frame/include/bli_kernel_macro_defs.h @@ -151,6 +151,7 @@ #define BLIS_FREE_USER free #endif + // -- Other system-related definitions ----------------------------------------- // Size of a virtual memory page. This is used to align blocks within the @@ -245,6 +246,7 @@ #define BLIS_POOL_ADDR_OFFSET_SIZE_GEN 0 #endif + // -- MR and NR blocksizes (only for reference kernels) ------------------------ // The build system defines BLIS_IN_REF_KERNEL, but only when compiling diff --git a/frame/include/bli_param_macro_defs.h b/frame/include/bli_param_macro_defs.h index 1822065dab..0865b11e99 100644 --- a/frame/include/bli_param_macro_defs.h +++ b/frame/include/bli_param_macro_defs.h @@ -927,7 +927,6 @@ BLIS_INLINE stor3_t bli_stor3_transb( stor3_t id ) } - // index-related BLIS_INLINE bool bli_is_edge_f( dim_t i, dim_t n_iter, dim_t n_left ) @@ -954,7 +953,7 @@ BLIS_INLINE bool bli_is_not_edge_b( dim_t i, dim_t n_iter, dim_t n_left ) ( i != 0 || n_left == 0 ); } -BLIS_INLINE bool bli_is_last_iter_sl( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) +BLIS_INLINE bool bli_is_last_iter_sl( dim_t i, dim_t end_iter ) { return ( bool ) ( i == end_iter - 1 ); @@ -966,15 +965,59 @@ BLIS_INLINE bool bli_is_last_iter_rr( dim_t i, dim_t end_iter, dim_t tid, dim_t ( i == end_iter - 1 - ( ( end_iter - tid - 1 ) % nth ) ); } -BLIS_INLINE bool bli_is_last_iter( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) +BLIS_INLINE bool bli_is_last_iter_slrr( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) { #ifdef BLIS_ENABLE_JRIR_SLAB - return bli_is_last_iter_sl( i, end_iter, tid, nth ); + return bli_is_last_iter_sl( i, end_iter ); #else // BLIS_ENABLE_JRIR_RR return bli_is_last_iter_rr( i, end_iter, tid, nth ); #endif } +BLIS_INLINE bool bli_is_last_iter_l( dim_t i, dim_t end_iter, dim_t tid, dim_t nth ) +{ + return bli_is_last_iter_slrr( i, end_iter, tid, nth ); +} + +BLIS_INLINE bool bli_is_last_iter_u( doff_t diagoff, dim_t mr, dim_t nr, inc_t inc ) +{ + return bli_is_strictly_below_diag_n( diagoff + inc*mr, mr, nr ); +} + +BLIS_INLINE bool bli_is_last_iter_tlb_l( dim_t i, dim_t end_iter ) +{ + return bli_is_last_iter_sl( i, end_iter ); +} + +BLIS_INLINE bool bli_is_last_iter_tlb_u( doff_t diagoff, dim_t mr, dim_t nr ) +{ + return bli_is_strictly_below_diag_n( diagoff + 1*mr, mr, nr ); +} + +BLIS_INLINE bool bli_is_my_iter_sl( dim_t i, dim_t st, dim_t en ) +{ + return ( st <= i && i < en ); +} + +BLIS_INLINE bool bli_is_my_iter_rr( dim_t i, dim_t work_id, dim_t n_way ) +{ + return ( i % n_way == work_id % n_way ); +} + +BLIS_INLINE bool bli_is_my_iter( dim_t i, dim_t st, dim_t en, dim_t work_id, dim_t n_way ) +{ + // NOTE: This function is (as of this writing) only called from packm. + // If the structure of the cpp macros below is ever changed, make sure + // it is still consistent with that of bli_thread_range_slrr() since + // these functions are used together in packm. + +#ifdef BLIS_ENABLE_JRIR_RR + return bli_is_my_iter_rr( i, work_id, n_way ); +#else // ifdef ( _SLAB || _TLB ) + return bli_is_my_iter_sl( i, st, en ); +#endif +} + // packbuf_t-related diff --git a/frame/include/blis.h b/frame/include/blis.h index e75c3006e1..144ed1f431 100644 --- a/frame/include/blis.h +++ b/frame/include/blis.h @@ -80,6 +80,21 @@ extern "C" { #include "bli_pragma_macro_defs.h" +// -- Threading definitions -- + +#include "bli_thread.h" +#include "bli_thread_range.h" +#include "bli_thread_range_slab_rr.h" +#include "bli_thread_range_tlb.h" + +#include "bli_pthread.h" + + +// -- Constant definitions -- + +#include "bli_extern_defs.h" + + // -- BLIS architecture/kernel definitions -- #include "bli_l1v_ker_prot.h" diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 4cba76b207..d41f370539 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -108,907 +108,6 @@ void bli_thread_launch // ----------------------------------------------------------------------------- -void bli_thread_range_sub - ( - const thrinfo_t* thread, - dim_t n, - dim_t bf, - bool handle_edge_low, - dim_t* start, - dim_t* end - ) -{ - dim_t n_way = bli_thrinfo_n_way( thread ); - - if ( n_way == 1 ) { *start = 0; *end = n; return; } - - dim_t work_id = bli_thrinfo_work_id( thread ); - - dim_t all_start = 0; - dim_t all_end = n; - - dim_t size = all_end - all_start; - - dim_t n_bf_whole = size / bf; - dim_t n_bf_left = size % bf; - - dim_t n_bf_lo = n_bf_whole / n_way; - dim_t n_bf_hi = n_bf_whole / n_way; - - // In this function, we partition the space between all_start and - // all_end into n_way partitions, each a multiple of block_factor - // with the exception of the one partition that recieves the - // "edge" case (if applicable). - // - // Here are examples of various thread partitionings, in units of - // the block_factor, when n_way = 4. (A '+' indicates the thread - // that receives the leftover edge case (ie: n_bf_left extra - // rows/columns in its sub-range). - // (all_start ... all_end) - // n_bf_whole _left hel n_th_lo _hi thr0 thr1 thr2 thr3 - // 12 =0 f 0 4 3 3 3 3 - // 12 >0 f 0 4 3 3 3 3+ - // 13 >0 f 1 3 4 3 3 3+ - // 14 >0 f 2 2 4 4 3 3+ - // 15 >0 f 3 1 4 4 4 3+ - // 15 =0 f 3 1 4 4 4 3 - // - // 12 =0 t 4 0 3 3 3 3 - // 12 >0 t 4 0 3+ 3 3 3 - // 13 >0 t 3 1 3+ 3 3 4 - // 14 >0 t 2 2 3+ 3 4 4 - // 15 >0 t 1 3 3+ 4 4 4 - // 15 =0 t 1 3 3 4 4 4 - - // As indicated by the table above, load is balanced as equally - // as possible, even in the presence of an edge case. - - // First, we must differentiate between cases where the leftover - // "edge" case (n_bf_left) should be allocated to a thread partition - // at the low end of the index range or the high end. - - if ( handle_edge_low == FALSE ) - { - // Notice that if all threads receive the same number of - // block_factors, those threads are considered "high" and - // the "low" thread group is empty. - dim_t n_th_lo = n_bf_whole % n_way; - //dim_t n_th_hi = n_way - n_th_lo; - - // If some partitions must have more block_factors than others - // assign the slightly larger partitions to lower index threads. - if ( n_th_lo != 0 ) n_bf_lo += 1; - - // Compute the actual widths (in units of rows/columns) of - // individual threads in the low and high groups. - dim_t size_lo = n_bf_lo * bf; - dim_t size_hi = n_bf_hi * bf; - - // Precompute the starting indices of the low and high groups. - dim_t lo_start = all_start; - dim_t hi_start = all_start + n_th_lo * size_lo; - - // Compute the start and end of individual threads' ranges - // as a function of their work_ids and also the group to which - // they belong (low or high). - if ( work_id < n_th_lo ) - { - *start = lo_start + (work_id ) * size_lo; - *end = lo_start + (work_id+1) * size_lo; - } - else // if ( n_th_lo <= work_id ) - { - *start = hi_start + (work_id-n_th_lo ) * size_hi; - *end = hi_start + (work_id-n_th_lo+1) * size_hi; - - // Since the edge case is being allocated to the high - // end of the index range, we have to advance the last - // thread's end. - if ( work_id == n_way - 1 ) *end += n_bf_left; - } - } - else // if ( handle_edge_low == TRUE ) - { - // Notice that if all threads receive the same number of - // block_factors, those threads are considered "low" and - // the "high" thread group is empty. - dim_t n_th_hi = n_bf_whole % n_way; - dim_t n_th_lo = n_way - n_th_hi; - - // If some partitions must have more block_factors than others - // assign the slightly larger partitions to higher index threads. - if ( n_th_hi != 0 ) n_bf_hi += 1; - - // Compute the actual widths (in units of rows/columns) of - // individual threads in the low and high groups. - dim_t size_lo = n_bf_lo * bf; - dim_t size_hi = n_bf_hi * bf; - - // Precompute the starting indices of the low and high groups. - dim_t lo_start = all_start; - dim_t hi_start = all_start + n_th_lo * size_lo - + n_bf_left; - - // Compute the start and end of individual threads' ranges - // as a function of their work_ids and also the group to which - // they belong (low or high). - if ( work_id < n_th_lo ) - { - *start = lo_start + (work_id ) * size_lo; - *end = lo_start + (work_id+1) * size_lo; - - // Since the edge case is being allocated to the low - // end of the index range, we have to advance the - // starts/ends accordingly. - if ( work_id == 0 ) *end += n_bf_left; - else { *start += n_bf_left; - *end += n_bf_left; } - } - else // if ( n_th_lo <= work_id ) - { - *start = hi_start + (work_id-n_th_lo ) * size_hi; - *end = hi_start + (work_id-n_th_lo+1) * size_hi; - } - } -} - -siz_t bli_thread_range_l2r - ( - const thrinfo_t* thr, - const obj_t* a, - const blksz_t* bmult, - dim_t* start, - dim_t* end - ) -{ - num_t dt = bli_obj_dt( a ); - dim_t m = bli_obj_length_after_trans( a ); - dim_t n = bli_obj_width_after_trans( a ); - dim_t bf = bli_blksz_get_def( dt, bmult ); - - bli_thread_range_sub( thr, n, bf, - FALSE, start, end ); - - return m * ( *end - *start ); -} - -siz_t bli_thread_range_r2l - ( - const thrinfo_t* thr, - const obj_t* a, - const blksz_t* bmult, - dim_t* start, - dim_t* end - ) -{ - num_t dt = bli_obj_dt( a ); - dim_t m = bli_obj_length_after_trans( a ); - dim_t n = bli_obj_width_after_trans( a ); - dim_t bf = bli_blksz_get_def( dt, bmult ); - - bli_thread_range_sub( thr, n, bf, - TRUE, start, end ); - - return m * ( *end - *start ); -} - -siz_t bli_thread_range_t2b - ( - const thrinfo_t* thr, - const obj_t* a, - const blksz_t* bmult, - dim_t* start, - dim_t* end - ) -{ - num_t dt = bli_obj_dt( a ); - dim_t m = bli_obj_length_after_trans( a ); - dim_t n = bli_obj_width_after_trans( a ); - dim_t bf = bli_blksz_get_def( dt, bmult ); - - bli_thread_range_sub( thr, m, bf, - FALSE, start, end ); - - return n * ( *end - *start ); -} - -siz_t bli_thread_range_b2t - ( - const thrinfo_t* thr, - const obj_t* a, - const blksz_t* bmult, - dim_t* start, - dim_t* end - ) -{ - num_t dt = bli_obj_dt( a ); - dim_t m = bli_obj_length_after_trans( a ); - dim_t n = bli_obj_width_after_trans( a ); - dim_t bf = bli_blksz_get_def( dt, bmult ); - - bli_thread_range_sub( thr, m, bf, - TRUE, start, end ); - - return n * ( *end - *start ); -} - -// ----------------------------------------------------------------------------- - -dim_t bli_thread_range_width_l - ( - doff_t diagoff_j, - dim_t m, - dim_t n_j, - dim_t j, - dim_t n_way, - dim_t bf, - dim_t bf_left, - double area_per_thr, - bool handle_edge_low - ) -{ - dim_t width; - - // In this function, we assume that we are somewhere in the process of - // partitioning an m x n lower-stored region (with arbitrary diagonal - // offset) n_ways along the n dimension (into column panels). The value - // j identifies the left-to-right subpartition index (from 0 to n_way-1) - // of the subpartition whose width we are about to compute using the - // area per thread determined by the caller. n_j is the number of - // columns in the remaining region of the matrix being partitioned, - // and diagoff_j is that region's diagonal offset. - - // If this is the last subpartition, the width is simply equal to n_j. - // Note that this statement handles cases where the "edge case" (if - // one exists) is assigned to the high end of the index range (ie: - // handle_edge_low == FALSE). - if ( j == n_way - 1 ) return n_j; - - // At this point, we know there are at least two subpartitions left. - // We also know that IF the submatrix contains a completely dense - // rectangular submatrix, it will occur BEFORE the triangular (or - // trapezoidal) part. - - // Here, we implement a somewhat minor load balancing optimization - // that ends up getting employed only for relatively small matrices. - // First, recall that all subpartition widths will be some multiple - // of the blocking factor bf, except perhaps either the first or last - // subpartition, which will receive the edge case, if it exists. - // Also recall that j represents the current thread (or thread group, - // or "caucus") for which we are computing a subpartition width. - // If n_j is sufficiently small that we can only allocate bf columns - // to each of the remaining threads, then we set the width to bf. We - // do not allow the subpartition width to be less than bf, so, under - // some conditions, if n_j is small enough, some of the reamining - // threads may not get any work. For the purposes of this lower bound - // on work (ie: width >= bf), we allow the edge case to count as a - // "full" set of bf columns. - { - dim_t n_j_bf = n_j / bf + ( bf_left > 0 ? 1 : 0 ); - - if ( n_j_bf <= n_way - j ) - { - if ( j == 0 && handle_edge_low ) - width = ( bf_left > 0 ? bf_left : bf ); - else - width = bf; - - // Make sure that the width does not exceed n_j. This would - // occur if and when n_j_bf < n_way - j; that is, when the - // matrix being partitioned is sufficiently small relative to - // n_way such that there is not even enough work for every - // (remaining) thread to get bf (or bf_left) columns. The - // net effect of this safeguard is that some threads may get - // assigned empty ranges (ie: no work), which of course must - // happen in some situations. - if ( width > n_j ) width = n_j; - - return width; - } - } - - // This block computes the width assuming that we are entirely within - // a dense rectangle that precedes the triangular (or trapezoidal) - // part. - { - // First compute the width of the current panel under the - // assumption that the diagonal offset would not intersect. - width = ( dim_t )bli_round( ( double )area_per_thr / ( double )m ); - - // Adjust the width, if necessary. Specifically, we may need - // to allocate the edge case to the first subpartition, if - // requested; otherwise, we just need to ensure that the - // subpartition is a multiple of the blocking factor. - if ( j == 0 && handle_edge_low ) - { - if ( width % bf != bf_left ) width += bf_left - ( width % bf ); - } - else // if interior case - { - // Round up to the next multiple of the blocking factor. - //if ( width % bf != 0 ) width += bf - ( width % bf ); - // Round to the nearest multiple of the blocking factor. - if ( width % bf != 0 ) width = bli_round_to_mult( width, bf ); - } - } - - // We need to recompute width if the panel, according to the width - // as currently computed, would intersect the diagonal. - if ( diagoff_j < width ) - { - dim_t offm_inc, offn_inc; - - // Prune away the unstored region above the diagonal, if it exists. - // Note that the entire region was pruned initially, so we know that - // we don't need to try to prune the right side. (Also, we discard - // the offset deltas since we don't need to actually index into the - // subpartition.) - bli_prune_unstored_region_top_l( &diagoff_j, &m, &n_j, &offm_inc ); - //bli_prune_unstored_region_right_l( &diagoff_j, &m, &n_j, &offn_inc ); - - // We don't need offm_inc, offn_inc here. These statements should - // prevent compiler warnings. - ( void )offm_inc; - ( void )offn_inc; - - // Prepare to solve a quadratic equation to find the width of the - // current (jth) subpartition given the m dimension, diagonal offset, - // and area. - // NOTE: We know that the +/- in the quadratic formula must be a + - // here because we know that the desired solution (the subpartition - // width) will be smaller than (m + diagoff), not larger. If you - // don't believe me, draw a picture! - const double a = -0.5; - const double b = ( double )m + ( double )diagoff_j + 0.5; - const double c = -0.5 * ( ( double )diagoff_j * - ( ( double )diagoff_j + 1.0 ) - ) - area_per_thr; - const double r = b * b - 4.0 * a * c; - - // If the quadratic solution is not imaginary, round it and use that - // as our width, but make sure it didn't round to zero. Otherwise, - // discard the quadratic solution and leave width, as previously - // computed, unchanged. - if ( r >= 0.0 ) - { - const double x = ( -b + sqrt( r ) ) / ( 2.0 * a ); - - width = ( dim_t )bli_round( x ); - if ( width == 0 ) width = 1; - } - - // Adjust the width, if necessary. - if ( j == 0 && handle_edge_low ) - { - if ( width % bf != bf_left ) width += bf_left - ( width % bf ); - } - else // if interior case - { - // Round up to the next multiple of the blocking factor. - //if ( width % bf != 0 ) width += bf - ( width % bf ); - // Round to the nearest multiple of the blocking factor. - if ( width % bf != 0 ) width = bli_round_to_mult( width, bf ); - } - } - - // Make sure that the width, after being adjusted, does not cause the - // subpartition to exceed n_j. - if ( width > n_j ) width = n_j; - - return width; -} - -siz_t bli_find_area_trap_l - ( - dim_t m, - dim_t n, - doff_t diagoff - ) -{ - dim_t offm_inc = 0; - dim_t offn_inc = 0; - double tri_area; - double area; - - // Prune away any rectangular region above where the diagonal - // intersects the left edge of the subpartition, if it exists. - bli_prune_unstored_region_top_l( &diagoff, &m, &n, &offm_inc ); - - // Prune away any rectangular region to the right of where the - // diagonal intersects the bottom edge of the subpartition, if - // it exists. (This shouldn't ever be needed, since the caller - // would presumably have already performed rightward pruning, - // but it's here just in case.) - bli_prune_unstored_region_right_l( &diagoff, &m, &n, &offn_inc ); - - ( void )offm_inc; - ( void )offn_inc; - - // Compute the area of the empty triangle so we can subtract it - // from the area of the rectangle that bounds the subpartition. - if ( bli_intersects_diag_n( diagoff, m, n ) ) - { - double tri_dim = ( double )( n - diagoff - 1 ); - tri_area = tri_dim * ( tri_dim + 1.0 ) / 2.0; - } - else - { - // If the diagonal does not intersect the trapezoid, then - // we can compute the area as a simple rectangle. - tri_area = 0.0; - } - - area = ( double )m * ( double )n - tri_area; - - return ( siz_t )area; -} - -// ----------------------------------------------------------------------------- - -siz_t bli_thread_range_weighted_sub - ( - const thrinfo_t* thread, - doff_t diagoff, - uplo_t uplo, - dim_t m, - dim_t n, - dim_t bf, - bool handle_edge_low, - dim_t* j_start_thr, - dim_t* j_end_thr - ) -{ - dim_t n_way = bli_thrinfo_n_way( thread ); - dim_t my_id = bli_thrinfo_work_id( thread ); - - dim_t bf_left = n % bf; - - dim_t j; - - dim_t off_j; - doff_t diagoff_j; - dim_t n_left; - - dim_t width_j; - - dim_t offm_inc, offn_inc; - - double tri_dim, tri_area; - double area_total, area_per_thr; - - siz_t area = 0; - - // In this function, we assume that the caller has already determined - // that (a) the diagonal intersects the submatrix, and (b) the submatrix - // is either lower- or upper-stored. - - if ( bli_is_lower( uplo ) ) - { - // Prune away the unstored region above the diagonal, if it exists, - // and then to the right of where the diagonal intersects the bottom, - // if it exists. (Also, we discard the offset deltas since we don't - // need to actually index into the subpartition.) - bli_prune_unstored_region_top_l( &diagoff, &m, &n, &offm_inc ); - bli_prune_unstored_region_right_l( &diagoff, &m, &n, &offn_inc ); - - // We don't need offm_inc, offn_inc here. These statements should - // prevent compiler warnings. - ( void )offm_inc; - ( void )offn_inc; - - // Now that pruning has taken place, we know that diagoff >= 0. - - // Compute the total area of the submatrix, accounting for the - // location of the diagonal, and divide it by the number of ways - // of parallelism. - tri_dim = ( double )( n - diagoff - 1 ); - tri_area = tri_dim * ( tri_dim + 1.0 ) / 2.0; - area_total = ( double )m * ( double )n - tri_area; - area_per_thr = area_total / ( double )n_way; - - // Initialize some variables prior to the loop: the offset to the - // current subpartition, the remainder of the n dimension, and - // the diagonal offset of the current subpartition. - off_j = 0; - diagoff_j = diagoff; - n_left = n; - - // Iterate over the subpartition indices corresponding to each - // thread/caucus participating in the n_way parallelism. - for ( j = 0; j < n_way; ++j ) - { - // Compute the width of the jth subpartition, taking the - // current diagonal offset into account, if needed. - width_j = - bli_thread_range_width_l - ( - diagoff_j, m, n_left, - j, n_way, - bf, bf_left, - area_per_thr, - handle_edge_low - ); - - // If the current thread belongs to caucus j, this is his - // subpartition. So we compute the implied index range and - // end our search. - if ( j == my_id ) - { - *j_start_thr = off_j; - *j_end_thr = off_j + width_j; - - area = bli_find_area_trap_l( m, width_j, diagoff_j ); - - break; - } - - // Shift the current subpartition's starting and diagonal offsets, - // as well as the remainder of the n dimension, according to the - // computed width, and then iterate to the next subpartition. - off_j += width_j; - diagoff_j -= width_j; - n_left -= width_j; - } - } - else // if ( bli_is_upper( uplo ) ) - { - // Express the upper-stored case in terms of the lower-stored case. - - // First, we convert the upper-stored trapezoid to an equivalent - // lower-stored trapezoid by rotating it 180 degrees. - bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n ); - - // Now that the trapezoid is "flipped" in the n dimension, negate - // the bool that encodes whether to handle the edge case at the - // low (or high) end of the index range. - bli_toggle_bool( &handle_edge_low ); - - // Compute the appropriate range for the rotated trapezoid. - area = bli_thread_range_weighted_sub - ( - thread, diagoff, uplo, m, n, bf, - handle_edge_low, - j_start_thr, j_end_thr - ); - - // Reverse the indexing basis for the subpartition ranges so that - // the indices, relative to left-to-right iteration through the - // unrotated upper-stored trapezoid, map to the correct columns - // (relative to the diagonal). This amounts to subtracting the - // range from n. - bli_reverse_index_direction( n, j_start_thr, j_end_thr ); - } - - return area; -} - -siz_t bli_thread_range_mdim - ( - dir_t direct, - const thrinfo_t* thr, - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntl_t* cntl, - const cntx_t* cntx, - dim_t* start, - dim_t* end - ) -{ - bszid_t bszid = bli_cntl_bszid( cntl ); - opid_t family = bli_cntl_family( cntl ); - - // This is part of trsm's current implementation, whereby right side - // cases are implemented in left-side micro-kernels, which requires - // we swap the usage of the register blocksizes for the purposes of - // packing A and B. - if ( family == BLIS_TRSM ) - { - if ( bli_obj_root_is_triangular( a ) ) bszid = BLIS_MR; - else bszid = BLIS_NR; - } - - const blksz_t* bmult = bli_cntx_get_bmult( bszid, cntx ); - const obj_t* x; - bool use_weighted; - - // Use the operation family to choose the one of the two matrices - // being partitioned that potentially has structure, and also to - // decide whether or not we need to use weighted range partitioning. - // NOTE: It's important that we use non-weighted range partitioning - // for hemm and symm (ie: the gemm family) because the weighted - // function will mistakenly skip over unstored regions of the - // structured matrix, even though they represent part of that matrix - // that will be dense and full (after packing). - if ( family == BLIS_GEMM ) { x = a; use_weighted = FALSE; } - else if ( family == BLIS_GEMMT ) { x = c; use_weighted = TRUE; } - else if ( family == BLIS_TRMM ) { x = a; use_weighted = TRUE; } - else /*family == BLIS_TRSM*/ { x = a; use_weighted = FALSE; } - - if ( use_weighted ) - { - if ( direct == BLIS_FWD ) - return bli_thread_range_weighted_t2b( thr, x, bmult, start, end ); - else - return bli_thread_range_weighted_b2t( thr, x, bmult, start, end ); - } - else - { - if ( direct == BLIS_FWD ) - return bli_thread_range_t2b( thr, x, bmult, start, end ); - else - return bli_thread_range_b2t( thr, x, bmult, start, end ); - } -} - -siz_t bli_thread_range_ndim - ( - dir_t direct, - const thrinfo_t* thr, - const obj_t* a, - const obj_t* b, - const obj_t* c, - const cntl_t* cntl, - const cntx_t* cntx, - dim_t* start, - dim_t* end - ) -{ - bszid_t bszid = bli_cntl_bszid( cntl ); - opid_t family = bli_cntl_family( cntl ); - - // This is part of trsm's current implementation, whereby right side - // cases are implemented in left-side micro-kernels, which requires - // we swap the usage of the register blocksizes for the purposes of - // packing A and B. - if ( family == BLIS_TRSM ) - { - if ( bli_obj_root_is_triangular( b ) ) bszid = BLIS_MR; - else bszid = BLIS_NR; - } - - const blksz_t* bmult = bli_cntx_get_bmult( bszid, cntx ); - const obj_t* x; - bool use_weighted; - - // Use the operation family to choose the one of the two matrices - // being partitioned that potentially has structure, and also to - // decide whether or not we need to use weighted range partitioning. - // NOTE: It's important that we use non-weighted range partitioning - // for hemm and symm (ie: the gemm family) because the weighted - // function will mistakenly skip over unstored regions of the - // structured matrix, even though they represent part of that matrix - // that will be dense and full (after packing). - if ( family == BLIS_GEMM ) { x = b; use_weighted = FALSE; } - else if ( family == BLIS_GEMMT ) { x = c; use_weighted = TRUE; } - else if ( family == BLIS_TRMM ) { x = b; use_weighted = TRUE; } - else /*family == BLIS_TRSM*/ { x = b; use_weighted = FALSE; } - - if ( use_weighted ) - { - if ( direct == BLIS_FWD ) - return bli_thread_range_weighted_l2r( thr, x, bmult, start, end ); - else - return bli_thread_range_weighted_r2l( thr, x, bmult, start, end ); - } - else - { - if ( direct == BLIS_FWD ) - return bli_thread_range_l2r( thr, x, bmult, start, end ); - else - return bli_thread_range_r2l( thr, x, bmult, start, end ); - } -} - -siz_t bli_thread_range_weighted_l2r - ( - const thrinfo_t* thr, - const obj_t* a, - const blksz_t* bmult, - dim_t* start, - dim_t* end - ) -{ - siz_t area; - - // This function assigns area-weighted ranges in the n dimension - // where the total range spans 0 to n-1 with 0 at the left end and - // n-1 at the right end. - - if ( bli_obj_intersects_diag( a ) && - bli_obj_is_upper_or_lower( a ) ) - { - num_t dt = bli_obj_dt( a ); - doff_t diagoff = bli_obj_diag_offset( a ); - uplo_t uplo = bli_obj_uplo( a ); - dim_t m = bli_obj_length( a ); - dim_t n = bli_obj_width( a ); - dim_t bf = bli_blksz_get_def( dt, bmult ); - - // Support implicit transposition. - if ( bli_obj_has_trans( a ) ) - { - bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); - } - - area = - bli_thread_range_weighted_sub - ( - thr, diagoff, uplo, m, n, bf, - FALSE, start, end - ); - } - else // if dense or zeros - { - area = bli_thread_range_l2r - ( - thr, a, bmult, - start, end - ); - } - - return area; -} - -siz_t bli_thread_range_weighted_r2l - ( - const thrinfo_t* thr, - const obj_t* a, - const blksz_t* bmult, - dim_t* start, - dim_t* end - ) -{ - siz_t area; - - // This function assigns area-weighted ranges in the n dimension - // where the total range spans 0 to n-1 with 0 at the right end and - // n-1 at the left end. - - if ( bli_obj_intersects_diag( a ) && - bli_obj_is_upper_or_lower( a ) ) - { - num_t dt = bli_obj_dt( a ); - doff_t diagoff = bli_obj_diag_offset( a ); - uplo_t uplo = bli_obj_uplo( a ); - dim_t m = bli_obj_length( a ); - dim_t n = bli_obj_width( a ); - dim_t bf = bli_blksz_get_def( dt, bmult ); - - // Support implicit transposition. - if ( bli_obj_has_trans( a ) ) - { - bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); - } - - bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n ); - - area = - bli_thread_range_weighted_sub - ( - thr, diagoff, uplo, m, n, bf, - TRUE, start, end - ); - } - else // if dense or zeros - { - area = bli_thread_range_r2l - ( - thr, a, bmult, - start, end - ); - } - - return area; -} - -siz_t bli_thread_range_weighted_t2b - ( - const thrinfo_t* thr, - const obj_t* a, - const blksz_t* bmult, - dim_t* start, - dim_t* end - ) -{ - siz_t area; - - // This function assigns area-weighted ranges in the m dimension - // where the total range spans 0 to m-1 with 0 at the top end and - // m-1 at the bottom end. - - if ( bli_obj_intersects_diag( a ) && - bli_obj_is_upper_or_lower( a ) ) - { - num_t dt = bli_obj_dt( a ); - doff_t diagoff = bli_obj_diag_offset( a ); - uplo_t uplo = bli_obj_uplo( a ); - dim_t m = bli_obj_length( a ); - dim_t n = bli_obj_width( a ); - dim_t bf = bli_blksz_get_def( dt, bmult ); - - // Support implicit transposition. - if ( bli_obj_has_trans( a ) ) - { - bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); - } - - bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); - - area = - bli_thread_range_weighted_sub - ( - thr, diagoff, uplo, m, n, bf, - FALSE, start, end - ); - } - else // if dense or zeros - { - area = bli_thread_range_t2b - ( - thr, a, bmult, - start, end - ); - } - - return area; -} - -siz_t bli_thread_range_weighted_b2t - ( - const thrinfo_t* thr, - const obj_t* a, - const blksz_t* bmult, - dim_t* start, - dim_t* end - ) -{ - siz_t area; - - // This function assigns area-weighted ranges in the m dimension - // where the total range spans 0 to m-1 with 0 at the bottom end and - // m-1 at the top end. - - if ( bli_obj_intersects_diag( a ) && - bli_obj_is_upper_or_lower( a ) ) - { - num_t dt = bli_obj_dt( a ); - doff_t diagoff = bli_obj_diag_offset( a ); - uplo_t uplo = bli_obj_uplo( a ); - dim_t m = bli_obj_length( a ); - dim_t n = bli_obj_width( a ); - dim_t bf = bli_blksz_get_def( dt, bmult ); - - // Support implicit transposition. - if ( bli_obj_has_trans( a ) ) - { - bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); - } - - bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); - - bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n ); - - area = bli_thread_range_weighted_sub - ( - thr, diagoff, uplo, m, n, bf, - TRUE, start, end - ); - } - else // if dense or zeros - { - area = bli_thread_range_b2t - ( - thr, a, bmult, - start, end - ); - } - - return area; -} - -// ----------------------------------------------------------------------------- - void bli_prime_factorization( dim_t n, bli_prime_factors_t* factors ) { factors->n = n; diff --git a/frame/thread/bli_thread.h b/frame/thread/bli_thread.h index e61fc8b892..5002672dc4 100644 --- a/frame/thread/bli_thread.h +++ b/frame/thread/bli_thread.h @@ -56,6 +56,8 @@ typedef void (*thread_func_t)( thrcomm_t* gl_comm, dim_t tid, const void* params void bli_thread_init( void ); void bli_thread_finalize( void ); +// ----------------------------------------------------------------------------- + BLIS_EXPORT_BLIS void bli_thread_launch ( timpl_t ti, @@ -64,91 +66,6 @@ BLIS_EXPORT_BLIS void bli_thread_launch const void* params ); -// Thread range-related prototypes. - -BLIS_EXPORT_BLIS void bli_thread_range_sub - ( - const thrinfo_t* thread, - dim_t n, - dim_t bf, - bool handle_edge_low, - dim_t* start, - dim_t* end - ); - -#undef GENPROT -#define GENPROT( opname ) \ -\ -siz_t PASTEMAC0( opname ) \ - ( \ - dir_t direct, \ - const thrinfo_t* thr, \ - const obj_t* a, \ - const obj_t* b, \ - const obj_t* c, \ - const cntl_t* cntl, \ - const cntx_t* cntx, \ - dim_t* start, \ - dim_t* end \ - ); - -GENPROT( thread_range_mdim ) -GENPROT( thread_range_ndim ) - -#undef GENPROT -#define GENPROT( opname ) \ -\ -siz_t PASTEMAC0( opname ) \ - ( \ - const thrinfo_t* thr, \ - const obj_t* a, \ - const blksz_t* bmult, \ - dim_t* start, \ - dim_t* end \ - ); - -GENPROT( thread_range_l2r ) -GENPROT( thread_range_r2l ) -GENPROT( thread_range_t2b ) -GENPROT( thread_range_b2t ) - -GENPROT( thread_range_weighted_l2r ) -GENPROT( thread_range_weighted_r2l ) -GENPROT( thread_range_weighted_t2b ) -GENPROT( thread_range_weighted_b2t ) - - -dim_t bli_thread_range_width_l - ( - doff_t diagoff_j, - dim_t m, - dim_t n_j, - dim_t j, - dim_t n_way, - dim_t bf, - dim_t bf_left, - double area_per_thr, - bool handle_edge_low - ); -siz_t bli_find_area_trap_l - ( - dim_t m, - dim_t n, - doff_t diagoff - ); -siz_t bli_thread_range_weighted_sub - ( - const thrinfo_t* thread, - doff_t diagoff, - uplo_t uplo, - dim_t m, - dim_t n, - dim_t bf, - bool handle_edge_low, - dim_t* j_start_thr, - dim_t* j_end_thr - ); - // ----------------------------------------------------------------------------- // Factorization and partitioning prototypes @@ -212,98 +129,5 @@ BLIS_EXPORT_BLIS void bli_thread_set_thread_impl( timpl_t ti ); void bli_thread_init_rntm_from_env( rntm_t* rntm ); -// ----------------------------------------------------------------------------- - -BLIS_INLINE void bli_thread_range_jrir_rr - ( - const thrinfo_t* thread, - dim_t n, - dim_t bf, - bool handle_edge_low, - dim_t* start, - dim_t* end, - dim_t* inc - ) -{ - // Use interleaved partitioning of jr/ir loops. - *start = bli_thrinfo_work_id( thread ); - *inc = bli_thrinfo_n_way( thread ); - *end = n; -} - -BLIS_INLINE void bli_thread_range_jrir_sl - ( - const thrinfo_t* thread, - dim_t n, - dim_t bf, - bool handle_edge_low, - dim_t* start, - dim_t* end, - dim_t* inc - ) -{ - // Use contiguous slab partitioning of jr/ir loops. - bli_thread_range_sub( thread, n, bf, handle_edge_low, start, end ); - *inc = 1; -} - -BLIS_INLINE void bli_thread_range_jrir - ( - const thrinfo_t* thread, - dim_t n, - dim_t bf, - bool handle_edge_low, - dim_t* start, - dim_t* end, - dim_t* inc - ) -{ - // Define a general-purpose version of bli_thread_range_jrir() whose - // definition depends on whether slab or round-robin partitioning was - // requested at configure-time. -#ifdef BLIS_ENABLE_JRIR_SLAB - bli_thread_range_jrir_sl( thread, n, bf, handle_edge_low, start, end, inc ); -#else - bli_thread_range_jrir_rr( thread, n, bf, handle_edge_low, start, end, inc ); -#endif -} - -#if 0 -BLIS_INLINE void bli_thread_range_weighted_jrir - ( - thrinfo_t* thread, - doff_t diagoff, - uplo_t uplo, - dim_t m, - dim_t n, - dim_t bf, - bool handle_edge_low, - dim_t* start, - dim_t* end, - dim_t* inc - ) -{ -#ifdef BLIS_ENABLE_JRIR_SLAB - - // Use contiguous slab partitioning for jr/ir loops. - bli_thread_range_weighted_sub( thread, diagoff, uplo, m, n, bf, - handle_edge_low, start, end ); - - *start = *start / bf; *inc = 1; - - if ( *end % bf ) *end = *end / bf + 1; - else *end = *end / bf; - -#else - // Use interleaved partitioning of jr/ir loops. - *start = bli_thrinfo_work_id( thread ); - *inc = bli_thrinfo_n_way( thread ); - *end = n; - -#endif -} #endif - -#endif - diff --git a/frame/thread/bli_thread_range.c b/frame/thread/bli_thread_range.c new file mode 100644 index 0000000000..a28e529b02 --- /dev/null +++ b/frame/thread/bli_thread_range.c @@ -0,0 +1,1121 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_thread_range_sub + ( + const thrinfo_t* thread, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end + ) +{ + dim_t n_way = bli_thrinfo_n_way( thread ); + + if ( n_way == 1 ) { *start = 0; *end = n; return; } + + dim_t work_id = bli_thrinfo_work_id( thread ); + + dim_t all_start = 0; + dim_t all_end = n; + + dim_t size = all_end - all_start; + + dim_t n_bf_whole = size / bf; + dim_t n_bf_left = size % bf; + + dim_t n_bf_lo = n_bf_whole / n_way; + dim_t n_bf_hi = n_bf_whole / n_way; + + // In this function, we partition the space between all_start and + // all_end into n_way partitions, each a multiple of block_factor + // with the exception of the one partition that recieves the + // "edge" case (if applicable). + // + // Here are examples of various thread partitionings, in units of + // the block_factor, when n_way = 4. (A '+' indicates the thread + // that receives the leftover edge case (ie: n_bf_left extra + // rows/columns in its sub-range). + // (all_start ... all_end) + // n_bf_whole _left hel n_th_lo _hi thr0 thr1 thr2 thr3 + // 12 =0 f 0 4 3 3 3 3 + // 12 >0 f 0 4 3 3 3 3+ + // 13 >0 f 1 3 4 3 3 3+ + // 14 >0 f 2 2 4 4 3 3+ + // 15 >0 f 3 1 4 4 4 3+ + // 15 =0 f 3 1 4 4 4 3 + // + // 12 =0 t 4 0 3 3 3 3 + // 12 >0 t 4 0 3+ 3 3 3 + // 13 >0 t 3 1 3+ 3 3 4 + // 14 >0 t 2 2 3+ 3 4 4 + // 15 >0 t 1 3 3+ 4 4 4 + // 15 =0 t 1 3 3 4 4 4 + + // As indicated by the table above, load is balanced as equally + // as possible, even in the presence of an edge case. + + // First, we must differentiate between cases where the leftover + // "edge" case (n_bf_left) should be allocated to a thread partition + // at the low end of the index range or the high end. + + if ( handle_edge_low == FALSE ) + { + // Notice that if all threads receive the same number of + // block_factors, those threads are considered "high" and + // the "low" thread group is empty. + dim_t n_th_lo = n_bf_whole % n_way; + //dim_t n_th_hi = n_way - n_th_lo; + + // If some partitions must have more block_factors than others + // assign the slightly larger partitions to lower index threads. + if ( n_th_lo != 0 ) n_bf_lo += 1; + + // Compute the actual widths (in units of rows/columns) of + // individual threads in the low and high groups. + dim_t size_lo = n_bf_lo * bf; + dim_t size_hi = n_bf_hi * bf; + + // Precompute the starting indices of the low and high groups. + dim_t lo_start = all_start; + dim_t hi_start = all_start + n_th_lo * size_lo; + + // Compute the start and end of individual threads' ranges + // as a function of their work_ids and also the group to which + // they belong (low or high). + if ( work_id < n_th_lo ) + { + *start = lo_start + (work_id ) * size_lo; + *end = lo_start + (work_id+1) * size_lo; + } + else // if ( n_th_lo <= work_id ) + { + *start = hi_start + (work_id-n_th_lo ) * size_hi; + *end = hi_start + (work_id-n_th_lo+1) * size_hi; + + // Since the edge case is being allocated to the high + // end of the index range, we have to advance the last + // thread's end. + if ( work_id == n_way - 1 ) *end += n_bf_left; + } + } + else // if ( handle_edge_low == TRUE ) + { + // Notice that if all threads receive the same number of + // block_factors, those threads are considered "low" and + // the "high" thread group is empty. + dim_t n_th_hi = n_bf_whole % n_way; + dim_t n_th_lo = n_way - n_th_hi; + + // If some partitions must have more block_factors than others + // assign the slightly larger partitions to higher index threads. + if ( n_th_hi != 0 ) n_bf_hi += 1; + + // Compute the actual widths (in units of rows/columns) of + // individual threads in the low and high groups. + dim_t size_lo = n_bf_lo * bf; + dim_t size_hi = n_bf_hi * bf; + + // Precompute the starting indices of the low and high groups. + dim_t lo_start = all_start; + dim_t hi_start = all_start + n_th_lo * size_lo + + n_bf_left; + + // Compute the start and end of individual threads' ranges + // as a function of their work_ids and also the group to which + // they belong (low or high). + if ( work_id < n_th_lo ) + { + *start = lo_start + (work_id ) * size_lo; + *end = lo_start + (work_id+1) * size_lo; + + // Since the edge case is being allocated to the low + // end of the index range, we have to advance the + // starts/ends accordingly. + if ( work_id == 0 ) *end += n_bf_left; + else { *start += n_bf_left; + *end += n_bf_left; } + } + else // if ( n_th_lo <= work_id ) + { + *start = hi_start + (work_id-n_th_lo ) * size_hi; + *end = hi_start + (work_id-n_th_lo+1) * size_hi; + } + } +} + +// ----------------------------------------------------------------------------- + +siz_t bli_thread_range_l2r + ( + const thrinfo_t* thr, + const obj_t* a, + const blksz_t* bmult, + dim_t* start, + dim_t* end + ) +{ + num_t dt = bli_obj_dt( a ); + dim_t m = bli_obj_length_after_trans( a ); + dim_t n = bli_obj_width_after_trans( a ); + dim_t bf = bli_blksz_get_def( dt, bmult ); + + bli_thread_range_sub( thr, n, bf, + FALSE, start, end ); + + return m * ( *end - *start ); +} + +siz_t bli_thread_range_r2l + ( + const thrinfo_t* thr, + const obj_t* a, + const blksz_t* bmult, + dim_t* start, + dim_t* end + ) +{ + num_t dt = bli_obj_dt( a ); + dim_t m = bli_obj_length_after_trans( a ); + dim_t n = bli_obj_width_after_trans( a ); + dim_t bf = bli_blksz_get_def( dt, bmult ); + + bli_thread_range_sub( thr, n, bf, + TRUE, start, end ); + + return m * ( *end - *start ); +} + +siz_t bli_thread_range_t2b + ( + const thrinfo_t* thr, + const obj_t* a, + const blksz_t* bmult, + dim_t* start, + dim_t* end + ) +{ + num_t dt = bli_obj_dt( a ); + dim_t m = bli_obj_length_after_trans( a ); + dim_t n = bli_obj_width_after_trans( a ); + dim_t bf = bli_blksz_get_def( dt, bmult ); + + bli_thread_range_sub( thr, m, bf, + FALSE, start, end ); + + return n * ( *end - *start ); +} + +siz_t bli_thread_range_b2t + ( + const thrinfo_t* thr, + const obj_t* a, + const blksz_t* bmult, + dim_t* start, + dim_t* end + ) +{ + num_t dt = bli_obj_dt( a ); + dim_t m = bli_obj_length_after_trans( a ); + dim_t n = bli_obj_width_after_trans( a ); + dim_t bf = bli_blksz_get_def( dt, bmult ); + + bli_thread_range_sub( thr, m, bf, + TRUE, start, end ); + + return n * ( *end - *start ); +} + +// ----------------------------------------------------------------------------- + +dim_t bli_thread_range_width_l + ( + doff_t diagoff_j, + dim_t m, + dim_t n_j, + dim_t j, + dim_t n_way, + dim_t bf, + dim_t bf_left, + double area_per_thr, + bool handle_edge_low + ) +{ + dim_t width; + + // In this function, we assume that we are somewhere in the process of + // partitioning an m x n lower-stored region (with arbitrary diagonal + // offset) n_ways along the n dimension (into column panels). The value + // j identifies the left-to-right subpartition index (from 0 to n_way-1) + // of the subpartition whose width we are about to compute using the + // area per thread determined by the caller. n_j is the number of + // columns in the remaining region of the matrix being partitioned, + // and diagoff_j is that region's diagonal offset. + + // If this is the last subpartition, the width is simply equal to n_j. + // Note that this statement handles cases where the "edge case" (if + // one exists) is assigned to the high end of the index range (ie: + // handle_edge_low == FALSE). + if ( j == n_way - 1 ) return n_j; + + // At this point, we know there are at least two subpartitions left. + // We also know that IF the submatrix contains a completely dense + // rectangular submatrix, it will occur BEFORE the triangular (or + // trapezoidal) part. + + // Here, we implement a somewhat minor load balancing optimization + // that ends up getting employed only for relatively small matrices. + // First, recall that all subpartition widths will be some multiple + // of the blocking factor bf, except perhaps either the first or last + // subpartition, which will receive the edge case, if it exists. + // Also recall that j represents the current thread (or thread group, + // or "caucus") for which we are computing a subpartition width. + // If n_j is sufficiently small that we can only allocate bf columns + // to each of the remaining threads, then we set the width to bf. We + // do not allow the subpartition width to be less than bf, so, under + // some conditions, if n_j is small enough, some of the reamining + // threads may not get any work. For the purposes of this lower bound + // on work (ie: width >= bf), we allow the edge case to count as a + // "full" set of bf columns. + { + dim_t n_j_bf = n_j / bf + ( bf_left > 0 ? 1 : 0 ); + + if ( n_j_bf <= n_way - j ) + { + if ( j == 0 && handle_edge_low ) + width = ( bf_left > 0 ? bf_left : bf ); + else + width = bf; + + // Make sure that the width does not exceed n_j. This would + // occur if and when n_j_bf < n_way - j; that is, when the + // matrix being partitioned is sufficiently small relative to + // n_way such that there is not even enough work for every + // (remaining) thread to get bf (or bf_left) columns. The + // net effect of this safeguard is that some threads may get + // assigned empty ranges (ie: no work), which of course must + // happen in some situations. + if ( width > n_j ) width = n_j; + + return width; + } + } + + // This block computes the width assuming that we are entirely within + // a dense rectangle that precedes the triangular (or trapezoidal) + // part. + { + // First compute the width of the current panel under the + // assumption that the diagonal offset would not intersect. + width = ( dim_t )bli_round( ( double )area_per_thr / ( double )m ); + + // Adjust the width, if necessary. Specifically, we may need + // to allocate the edge case to the first subpartition, if + // requested; otherwise, we just need to ensure that the + // subpartition is a multiple of the blocking factor. + if ( j == 0 && handle_edge_low ) + { + if ( width % bf != bf_left ) width += bf_left - ( width % bf ); + } + else // if interior case + { + // Round up to the next multiple of the blocking factor. + //if ( width % bf != 0 ) width += bf - ( width % bf ); + // Round to the nearest multiple of the blocking factor. + if ( width % bf != 0 ) width = bli_round_to_mult( width, bf ); + } + } + + // We need to recompute width if the panel, according to the width + // as currently computed, would intersect the diagonal. + if ( diagoff_j < width ) + { + dim_t offm_inc, offn_inc; + + // Prune away the unstored region above the diagonal, if it exists. + // Note that the entire region was pruned initially, so we know that + // we don't need to try to prune the right side. (Also, we discard + // the offset deltas since we don't need to actually index into the + // subpartition.) + bli_prune_unstored_region_top_l( &diagoff_j, &m, &n_j, &offm_inc ); + //bli_prune_unstored_region_right_l( &diagoff_j, &m, &n_j, &offn_inc ); + + // We don't need offm_inc, offn_inc here. These statements should + // prevent compiler warnings. + ( void )offm_inc; + ( void )offn_inc; + + // Prepare to solve a quadratic equation to find the width of the + // current (jth) subpartition given the m dimension, diagonal offset, + // and area. + // NOTE: We know that the +/- in the quadratic formula must be a + + // here because we know that the desired solution (the subpartition + // width) will be smaller than (m + diagoff), not larger. If you + // don't believe me, draw a picture! + const double a = -0.5; + const double b = ( double )m + ( double )diagoff_j + 0.5; + const double c = -0.5 * ( ( double )diagoff_j * + ( ( double )diagoff_j + 1.0 ) + ) - area_per_thr; + const double r = b * b - 4.0 * a * c; + + // If the quadratic solution is not imaginary, round it and use that + // as our width (but make sure it didn't round to zero). Otherwise, + // discard the quadratic solution and leave width, as previously + // computed, unchanged. + if ( r >= 0.0 ) + { + const double x = ( -b + sqrt( r ) ) / ( 2.0 * a ); + + width = ( dim_t )bli_round( x ); + if ( width == 0 ) width = 1; + } + + // Adjust the width, if necessary. + if ( j == 0 && handle_edge_low ) + { + if ( width % bf != bf_left ) width += bf_left - ( width % bf ); + } + else // if interior case + { + // Round up to the next multiple of the blocking factor. + //if ( width % bf != 0 ) width += bf - ( width % bf ); + // Round to the nearest multiple of the blocking factor. + if ( width % bf != 0 ) width = bli_round_to_mult( width, bf ); + } + } + + // Make sure that the width, after being adjusted, does not cause the + // subpartition to exceed n_j. + if ( width > n_j ) width = n_j; + + return width; +} + +siz_t bli_find_area_trap_l + ( + doff_t diagoff, + dim_t m, + dim_t n, + dim_t bf + ) +{ + dim_t offm_inc = 0; + dim_t offn_inc = 0; + double utri_area; + double blktri_area; + + // Prune away any rectangular region above where the diagonal + // intersects the left edge of the subpartition, if it exists. + bli_prune_unstored_region_top_l( &diagoff, &m, &n, &offm_inc ); + + // Prune away any rectangular region to the right of where the + // diagonal intersects the bottom edge of the subpartition, if + // it exists. (This shouldn't ever be needed, since the caller + // would presumably have already performed rightward pruning, + // but it's here just in case.) + //bli_prune_unstored_region_right_l( &diagoff, &m, &n, &offn_inc ); + + ( void )offm_inc; + ( void )offn_inc; + + // Compute the area of the empty triangle so we can subtract it + // from the area of the rectangle that bounds the subpartition. + if ( bli_intersects_diag_n( diagoff, m, n ) ) + { + double tri_dim = ( double )( n - diagoff - 1 ); + tri_dim = bli_min( tri_dim, m - 1 ); + + utri_area = tri_dim * ( tri_dim + 1.0 ) / 2.0; + blktri_area = tri_dim * ( bf - 1.0 ) / 2.0; + } + else + { + // If the diagonal does not intersect the trapezoid, then + // we can compute the area as a simple rectangle. + utri_area = 0.0; + blktri_area = 0.0; + } + + double area = ( double )m * ( double )n - utri_area + blktri_area; + + return ( siz_t )area; +} + +// ----------------------------------------------------------------------------- + +siz_t bli_thread_range_weighted_sub + ( + const thrinfo_t* thread, + doff_t diagoff, + uplo_t uplo, + uplo_t uplo_orig, + dim_t m, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* j_start_thr, + dim_t* j_end_thr + ) +{ + dim_t n_way = bli_thrinfo_n_way( thread ); + dim_t my_id = bli_thrinfo_work_id( thread ); + + dim_t bf_left = n % bf; + + dim_t offm_inc, offn_inc; + + siz_t area = 0; + + // In this function, we assume that the caller has already determined + // that (a) the diagonal intersects the submatrix, and (b) the submatrix + // is either lower- or upper-stored. + + if ( bli_is_lower( uplo ) ) + { + #if 0 + if ( n_way > 1 ) + printf( "thread_range_weighted_sub(): tid %d: m n = %3d %3d do %d (lower)\n", + (int)my_id, (int)(m), (int)(n), (int)(diagoff) ); + #endif + + // Prune away the unstored region above the diagonal, if it exists, + // and then to the right of where the diagonal intersects the bottom, + // if it exists. (Also, we discard the offset deltas since we don't + // need to actually index into the subpartition.) + bli_prune_unstored_region_top_l( &diagoff, &m, &n, &offm_inc ); + + if ( !handle_edge_low ) + { + // This branch handles the following two cases: + // - note: Edge case microtiles are marked as 'e'. + // + // uplo_orig = lower | uplo = lower + // handle edge high (orig) | handle edge high + // + // x x x x x x x x x x x x x x + // x x x x x x x x x x x x x x x x + // x x x x x x x x x -> x x x x x x x x x + // x x x x x x x x x x x x x x x x x x x x + // x x x x x x x x x x e x x x x x x x x x x e + // x x x x x x x x x x e x x x x x x x x x x e + // + // uplo_orig = upper | uplo = lower + // handle edge low (orig) | handle edge high + // + // e x x x x x x x x x x x x x x x x x + // e x x x x x x x x x x x x x x x x x x + // x x x x x x x x x x -> x x x x x x x x x + // x x x x x x x x x x x x x x x x x x x + // x x x x x x x x x x x x x x x x x x e + // x x x x x x x x x x x x x x x x x e + + // If the edge case is being handled "high", then we can employ this + // simple macro for pruning the region to the right of where the + // diagonal intersets the right side of the submatrix (which amounts + // to adjusting the n dimension). + bli_prune_unstored_region_right_l( &diagoff, &m, &n, &offn_inc ); + } + else // if ( handle_edge_low ) + { + // This branch handles the following two cases: + // + // uplo_orig = upper | uplo = lower + // handle edge high (orig) | handle edge low + // + // x x x x x x x x x x e e x x x x x x + // x x x x x x x x x x e e x x x x x x x + // x x x x x x x x x e -> e x x x x x x x x + // x x x x x x x x e e x x x x x x x x x + // x x x x x x x e e x x x x x x x x x x + // x x x x x x e e x x x x x x x x x x + // + // uplo_orig = lower | uplo = lower + // handle edge low (orig) | handle edge low + // + // e x x x x x x e x x x x x x + // e x x x x x x x e x x x x x x x + // e x x x x x x x x -> e x x x x x x x x + // e x x x x x x x x x e x x x x x x x x x + // e x x x x x x x x x x e x x x x x x x x x x + // e x x x x x x x x x x e x x x x x x x x x x + + // If the edge case is being handled "low", then we have to be more + // careful. The problem can be seen in certain situations when we're + // actually computing the weighted ranges for an upper-stored + // subpartition whose (a) diagonal offset is positive (though will + // always be less than NR), (b) right-side edge case exists, and (c) + // sum of (a) and (b) is less than NR. This is a problem because the + // upcoming loop that iterates over/ bli_thread_range_width_l() + // doesn't realize that the offsets associated with (a) and (b) + // belong on two separate columns of microtiles. If we naively use + // bli_prune_unstored_region_right_l() when handle_edge_low == TRUE, + // the loop over bli_thread_range_width_l() will only "see" p-1 + // IR-iterations of work to assign to threads when there are + // actually p micropanels. + + const dim_t n_inner = ( diagoff + bli_min( m, n - diagoff ) - bf_left ); + + const dim_t n_bf_iter_br = n_inner / bf; + const dim_t n_bf_left_br = n_inner % bf; + const dim_t n_bf_br = ( bf_left > 0 ? 1 : 0 ) + + n_bf_iter_br + + ( n_bf_left_br > 0 ? 1 : 0 ); + + // Compute the number of extra columns that were included in n_bf_br + // as a result of including a full micropanel for the part of the + // submatrix that contains bf_left columns. For example, if bf = 16 + // and bf_left = 4, then bf_extra = 12. But if bf_left = 0, then we + // didn't include any extra columns. + const dim_t bf_extra = ( bf_left > 0 ? bf - bf_left : 0 ); + + // Subtract off bf_extra from n_bf_br to arrive at the "true" value + // of n that we'll use going forward. + n = n_bf_br * bf - bf_extra; + + #if 0 + if ( n_way > 1 ) + { + //printf( "thread_range_weighted_sub(): tid %d: _iter _left = %3d %3d (lower1)\n", + // (int)my_id, (int)n_bf_iter_br, (int)n_bf_left_br ); + printf( "thread_range_weighted_sub(): tid %d: m n = %3d %3d do %d (lower2)\n", + (int)my_id, (int)(m), (int)(n), (int)(diagoff) ); + } + #endif + } + + // We don't need offm_inc, offn_inc here. These statements should + // prevent compiler warnings. + ( void )offm_inc; + ( void )offn_inc; + + // Now that pruning has taken place, we know that diagoff >= 0. + + // Compute the total area of the submatrix, accounting for the + // location of the diagonal. This is done by computing the area in + // the strictly upper triangle, subtracting it off the area of the + // full rectangle, and then adding the missing strictly upper + // triangles of the bf x bf blocks along the diagonal. + double tri_dim = ( double )( n - diagoff - 1 ); + tri_dim = bli_min( tri_dim, m - 1 ); + double utri_area = tri_dim * ( tri_dim + 1.0 ) / 2.0; + + // Note that the expression below is the simplified form of: + // blktri_area = ( tri_dim / bf ) * bf * ( bf - 1.0 ) / 2.0; + double blktri_area = tri_dim * ( bf - 1.0 ) / 2.0; + + // Compute the area of the region to the right of where the diagonal + // intersects the bottom edge of the submatrix. If it instead intersects + // the right edge (or the bottom-right corner), then this region does + // not exist and so its area is explicitly set to zero. + double beyondtri_dim = n - diagoff - m; + double beyondtri_area; + if ( 0 < beyondtri_dim ) beyondtri_area = beyondtri_dim * m; + else beyondtri_area = 0.0; + + // Here, we try to account for the added cost of computing columns of + // microtiles that intersect the diagonal. This is rather difficult to + // model, but this is partly due to the way non-square microtiles map + // onto the matrix relative to the diagonal, as well as additional + // overhead incurred from (potentially) computing with less-than-full + // columns of microtiles (i.e., columns for which diagoff_j < 0). + // Note that higher values for blktri_area have the net effect of + // increasing the relative size of slabs that share little or no overlap + // with the diagonal region. this is because it slightly increases the + // total area computation below, which in turn increases the area + // targeted by each thread/group earlier in the thread range, which + // for lower trapezoidal submatrices, corresponds to the regular + // rectangular region that precedes the diagonal part (if such a + // rectangular region exists). + blktri_area *= 1.5; + //blktri_area = 0.0; + + double area_total = ( double )m * ( double )n - utri_area + blktri_area + - beyondtri_area; + + // Divide the computed area by the number of ways of parallelism. + double area_per_thr = area_total / ( double )n_way; + + + // Initialize some variables prior to the loop: the offset to the + // current subpartition, the remainder of the n dimension, and + // the diagonal offset of the current subpartition. + dim_t off_j = 0; + doff_t diagoff_j = diagoff; + dim_t n_left = n; + + #if 0 + printf( "thread_range_weighted_sub(): tid %d: n_left = %3d (lower4)\n", + (int)my_id, (int)(n_left) ); + #endif + + // Iterate over the subpartition indices corresponding to each + // thread/caucus participating in the n_way parallelism. + for ( dim_t j = 0; j < n_way; ++j ) + { + // Compute the width of the jth subpartition, taking the + // current diagonal offset into account, if needed. + dim_t width_j + = + bli_thread_range_width_l + ( + diagoff_j, m, n_left, + j, n_way, + bf, bf_left, + area_per_thr, + handle_edge_low + ); + + #if 0 + if ( n_way > 1 ) + printf( "thread_range_weighted_sub(): tid %d: width_j = %d doff_j = %d\n", + (int)my_id, (int)width_j, (int)diagoff_j ); + #endif + + // If the current thread belongs to caucus j, this is his + // subpartition. So we compute the implied index range and + // end our search. + #if 0 + // An alternate way of assigning work to threads such that regions + // are assigned to threads left to right *after* accounting for the + // fact that we recycle the same lower-trapezoidal code to also + // compute the upper-trapezoidal case. + bool is_my_range; + if ( bli_is_lower( uplo_orig ) ) is_my_range = ( j == my_id ); + else is_my_range = ( j == n_way - my_id - 1 ); + #else + bool is_my_range = ( j == my_id ); + #endif + + if ( is_my_range ) + { + *j_start_thr = off_j; + *j_end_thr = off_j + width_j; + + #if 0 + if ( n_way > 1 ) + printf( "thread_range_weighted_sub(): tid %d: sta end = %3d %3d\n", + (int)my_id, (int)(*j_start_thr), (int)(*j_end_thr) ); + //printf( "thread_range_weighted_sub(): tid %d: n_left = %3d\n", + // (int)my_id, (int)(n) ); + #endif + + // Compute the area of the thread's current subpartition in case + // the caller is curious how much work they were assigned. + // NOTE: This area computation isn't actually needed for BLIS to + // function properly.) + area = bli_find_area_trap_l( diagoff_j, m, width_j, bf ); + + break; + } + + // Shift the current subpartition's starting and diagonal offsets, + // as well as the remainder of the n dimension, according to the + // computed width, and then iterate to the next subpartition. + off_j += width_j; + diagoff_j -= width_j; + n_left -= width_j; + } + } + else // if ( bli_is_upper( uplo ) ) + { + // Express the upper-stored case in terms of the lower-stored case. + + #if 0 + if ( n_way > 1 ) + printf( "thread_range_weighted_sub(): tid %d: m n = %3d %3d do %d (upper)\n", + (int)my_id, (int)(m), (int)(n), (int)(diagoff) ); + #endif + + // First, we convert the upper-stored trapezoid to an equivalent + // lower-stored trapezoid by rotating it 180 degrees. + bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n ); + + // Now that the trapezoid is "flipped" in the n dimension, negate + // the bool that encodes whether to handle the edge case at the + // low (or high) end of the index range. + bli_toggle_bool( &handle_edge_low ); + + // Compute the appropriate range for the rotated trapezoid. + area = bli_thread_range_weighted_sub + ( + thread, diagoff, uplo, uplo_orig, m, n, bf, + handle_edge_low, + j_start_thr, j_end_thr + ); + + // Reverse the indexing basis for the subpartition ranges so that + // the indices, relative to left-to-right iteration through the + // unrotated upper-stored trapezoid, map to the correct columns + // (relative to the diagonal). This amounts to subtracting the + // range from n. + bli_reverse_index_direction( n, j_start_thr, j_end_thr ); + } + + return area; +} + +// ----------------------------------------------------------------------------- + +siz_t bli_thread_range_mdim + ( + dir_t direct, + const thrinfo_t* thr, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntl_t* cntl, + const cntx_t* cntx, + dim_t* start, + dim_t* end + ) +{ + bszid_t bszid = bli_cntl_bszid( cntl ); + opid_t family = bli_cntl_family( cntl ); + + // This is part of trsm's current implementation, whereby right side + // cases are implemented in left-side micro-kernels, which requires + // we swap the usage of the register blocksizes for the purposes of + // packing A and B. + if ( family == BLIS_TRSM ) + { + if ( bli_obj_root_is_triangular( a ) ) bszid = BLIS_MR; + else bszid = BLIS_NR; + } + + const blksz_t* bmult = bli_cntx_get_bmult( bszid, cntx ); + const obj_t* x; + bool use_weighted; + + // Use the operation family to choose the one of the two matrices + // being partitioned that potentially has structure, and also to + // decide whether or not we need to use weighted range partitioning. + // NOTE: It's important that we use non-weighted range partitioning + // for hemm and symm (ie: the gemm family) because the weighted + // function will mistakenly skip over unstored regions of the + // structured matrix, even though they represent part of that matrix + // that will be dense and full (after packing). + if ( family == BLIS_GEMM ) { x = a; use_weighted = FALSE; } + else if ( family == BLIS_GEMMT ) { x = c; use_weighted = TRUE; } + else if ( family == BLIS_TRMM ) { x = a; use_weighted = TRUE; } + else /*family == BLIS_TRSM*/ { x = a; use_weighted = FALSE; } + + if ( use_weighted ) + { + if ( direct == BLIS_FWD ) + return bli_thread_range_weighted_t2b( thr, x, bmult, start, end ); + else + return bli_thread_range_weighted_b2t( thr, x, bmult, start, end ); + } + else + { + if ( direct == BLIS_FWD ) + return bli_thread_range_t2b( thr, x, bmult, start, end ); + else + return bli_thread_range_b2t( thr, x, bmult, start, end ); + } +} + +siz_t bli_thread_range_ndim + ( + dir_t direct, + const thrinfo_t* thr, + const obj_t* a, + const obj_t* b, + const obj_t* c, + const cntl_t* cntl, + const cntx_t* cntx, + dim_t* start, + dim_t* end + ) +{ + bszid_t bszid = bli_cntl_bszid( cntl ); + opid_t family = bli_cntl_family( cntl ); + + // This is part of trsm's current implementation, whereby right side + // cases are implemented in left-side micro-kernels, which requires + // we swap the usage of the register blocksizes for the purposes of + // packing A and B. + if ( family == BLIS_TRSM ) + { + if ( bli_obj_root_is_triangular( b ) ) bszid = BLIS_MR; + else bszid = BLIS_NR; + } + + const blksz_t* bmult = bli_cntx_get_bmult( bszid, cntx ); + const obj_t* x; + bool use_weighted; + + // Use the operation family to choose the one of the two matrices + // being partitioned that potentially has structure, and also to + // decide whether or not we need to use weighted range partitioning. + // NOTE: It's important that we use non-weighted range partitioning + // for hemm and symm (ie: the gemm family) because the weighted + // function will mistakenly skip over unstored regions of the + // structured matrix, even though they represent part of that matrix + // that will be dense and full (after packing). + if ( family == BLIS_GEMM ) { x = b; use_weighted = FALSE; } + else if ( family == BLIS_GEMMT ) { x = c; use_weighted = TRUE; } + else if ( family == BLIS_TRMM ) { x = b; use_weighted = TRUE; } + else /*family == BLIS_TRSM*/ { x = b; use_weighted = FALSE; } + + if ( use_weighted ) + { + if ( direct == BLIS_FWD ) + return bli_thread_range_weighted_l2r( thr, x, bmult, start, end ); + else + return bli_thread_range_weighted_r2l( thr, x, bmult, start, end ); + } + else + { + if ( direct == BLIS_FWD ) + return bli_thread_range_l2r( thr, x, bmult, start, end ); + else + return bli_thread_range_r2l( thr, x, bmult, start, end ); + } +} + +// ----------------------------------------------------------------------------- + +siz_t bli_thread_range_weighted_l2r + ( + const thrinfo_t* thr, + const obj_t* a, + const blksz_t* bmult, + dim_t* start, + dim_t* end + ) +{ + siz_t area; + + // This function assigns area-weighted ranges in the n dimension + // where the total range spans 0 to n-1 with 0 at the left end and + // n-1 at the right end. + + if ( bli_obj_intersects_diag( a ) && + bli_obj_is_upper_or_lower( a ) ) + { + num_t dt = bli_obj_dt( a ); + doff_t diagoff = bli_obj_diag_offset( a ); + uplo_t uplo = bli_obj_uplo( a ); + dim_t m = bli_obj_length( a ); + dim_t n = bli_obj_width( a ); + dim_t bf = bli_blksz_get_def( dt, bmult ); + + // Support implicit transposition. + if ( bli_obj_has_trans( a ) ) + { + bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); + } + + area = + bli_thread_range_weighted_sub + ( + thr, diagoff, uplo, uplo, m, n, bf, + FALSE, start, end + ); + } + else // if dense or zeros + { + area = bli_thread_range_l2r + ( + thr, a, bmult, + start, end + ); + } + + return area; +} + +siz_t bli_thread_range_weighted_r2l + ( + const thrinfo_t* thr, + const obj_t* a, + const blksz_t* bmult, + dim_t* start, + dim_t* end + ) +{ + siz_t area; + + // This function assigns area-weighted ranges in the n dimension + // where the total range spans 0 to n-1 with 0 at the right end and + // n-1 at the left end. + + if ( bli_obj_intersects_diag( a ) && + bli_obj_is_upper_or_lower( a ) ) + { + num_t dt = bli_obj_dt( a ); + doff_t diagoff = bli_obj_diag_offset( a ); + uplo_t uplo = bli_obj_uplo( a ); + dim_t m = bli_obj_length( a ); + dim_t n = bli_obj_width( a ); + dim_t bf = bli_blksz_get_def( dt, bmult ); + + // Support implicit transposition. + if ( bli_obj_has_trans( a ) ) + { + bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); + } + + bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n ); + + area = + bli_thread_range_weighted_sub + ( + thr, diagoff, uplo, uplo, m, n, bf, + TRUE, start, end + ); + } + else // if dense or zeros + { + area = bli_thread_range_r2l + ( + thr, a, bmult, + start, end + ); + } + + return area; +} + +siz_t bli_thread_range_weighted_t2b + ( + const thrinfo_t* thr, + const obj_t* a, + const blksz_t* bmult, + dim_t* start, + dim_t* end + ) +{ + siz_t area; + + // This function assigns area-weighted ranges in the m dimension + // where the total range spans 0 to m-1 with 0 at the top end and + // m-1 at the bottom end. + + if ( bli_obj_intersects_diag( a ) && + bli_obj_is_upper_or_lower( a ) ) + { + num_t dt = bli_obj_dt( a ); + doff_t diagoff = bli_obj_diag_offset( a ); + uplo_t uplo = bli_obj_uplo( a ); + dim_t m = bli_obj_length( a ); + dim_t n = bli_obj_width( a ); + dim_t bf = bli_blksz_get_def( dt, bmult ); + + // Support implicit transposition. + if ( bli_obj_has_trans( a ) ) + { + bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); + } + + bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); + + area = + bli_thread_range_weighted_sub + ( + thr, diagoff, uplo, uplo, m, n, bf, + FALSE, start, end + ); + } + else // if dense or zeros + { + area = bli_thread_range_t2b + ( + thr, a, bmult, + start, end + ); + } + + return area; +} + +siz_t bli_thread_range_weighted_b2t + ( + const thrinfo_t* thr, + const obj_t* a, + const blksz_t* bmult, + dim_t* start, + dim_t* end + ) +{ + siz_t area; + + // This function assigns area-weighted ranges in the m dimension + // where the total range spans 0 to m-1 with 0 at the bottom end and + // m-1 at the top end. + + if ( bli_obj_intersects_diag( a ) && + bli_obj_is_upper_or_lower( a ) ) + { + num_t dt = bli_obj_dt( a ); + doff_t diagoff = bli_obj_diag_offset( a ); + uplo_t uplo = bli_obj_uplo( a ); + dim_t m = bli_obj_length( a ); + dim_t n = bli_obj_width( a ); + dim_t bf = bli_blksz_get_def( dt, bmult ); + + // Support implicit transposition. + if ( bli_obj_has_trans( a ) ) + { + bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); + } + + bli_reflect_about_diag( &diagoff, &uplo, &m, &n ); + + bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n ); + + area = bli_thread_range_weighted_sub + ( + thr, diagoff, uplo, uplo, m, n, bf, + TRUE, start, end + ); + } + else // if dense or zeros + { + area = bli_thread_range_b2t + ( + thr, a, bmult, + start, end + ); + } + + return area; +} + diff --git a/frame/thread/bli_thread_range.h b/frame/thread/bli_thread_range.h new file mode 100644 index 0000000000..cf966b5a35 --- /dev/null +++ b/frame/thread/bli_thread_range.h @@ -0,0 +1,128 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2016, Hewlett Packard Enterprise Development LP + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_THREAD_RANGE_H +#define BLIS_THREAD_RANGE_H + +// Thread range-related prototypes. + +BLIS_EXPORT_BLIS void bli_thread_range_sub + ( + const thrinfo_t* thread, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end + ); + +#undef GENPROT +#define GENPROT( opname ) \ +\ +siz_t PASTEMAC0( opname ) \ + ( \ + dir_t direct, \ + const thrinfo_t* thr, \ + const obj_t* a, \ + const obj_t* b, \ + const obj_t* c, \ + const cntl_t* cntl, \ + const cntx_t* cntx, \ + dim_t* start, \ + dim_t* end \ + ); + +GENPROT( thread_range_mdim ) +GENPROT( thread_range_ndim ) + +#undef GENPROT +#define GENPROT( opname ) \ +\ +siz_t PASTEMAC0( opname ) \ + ( \ + const thrinfo_t* thr, \ + const obj_t* a, \ + const blksz_t* bmult, \ + dim_t* start, \ + dim_t* end \ + ); + +GENPROT( thread_range_l2r ) +GENPROT( thread_range_r2l ) +GENPROT( thread_range_t2b ) +GENPROT( thread_range_b2t ) + +GENPROT( thread_range_weighted_l2r ) +GENPROT( thread_range_weighted_r2l ) +GENPROT( thread_range_weighted_t2b ) +GENPROT( thread_range_weighted_b2t ) + + +dim_t bli_thread_range_width_l + ( + doff_t diagoff_j, + dim_t m, + dim_t n_j, + dim_t j, + dim_t n_way, + dim_t bf, + dim_t bf_left, + double area_per_thr, + bool handle_edge_low + ); +siz_t bli_find_area_trap_l + ( + doff_t diagoff, + dim_t m, + dim_t n, + dim_t bf + ); + +siz_t bli_thread_range_weighted_sub + ( + const thrinfo_t* thread, + doff_t diagoff, + uplo_t uplo, + uplo_t uplo_orig, + dim_t m, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* j_start_thr, + dim_t* j_end_thr + ); + +#endif diff --git a/frame/thread/bli_thread_range_slab_rr.c b/frame/thread/bli_thread_range_slab_rr.c new file mode 100644 index 0000000000..be44323096 --- /dev/null +++ b/frame/thread/bli_thread_range_slab_rr.c @@ -0,0 +1,134 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_thread_range_quad + ( + const thrinfo_t* thread, + doff_t diagoff, + uplo_t uplo, + dim_t m, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end, + dim_t* inc + ) +{ + +#ifdef BLIS_ENABLE_JRIR_RR + + const dim_t tid = bli_thrinfo_work_id( thread ); + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t n_iter = n / bf + ( n % bf ? 1 : 0 ); + + // Use round-robin (interleaved) partitioning of jr/ir loops. + *start = tid; + *end = n_iter; + *inc = jr_nt; + +#else // #elif defined( BLIS_ENABLE_JRIR_SLAB ) || + // defined( BLIS_ENABLE_JRIR_TLB ) + + // NOTE: While this cpp conditional branch applies to both _SLAB and _TLB + // cases, this *function* should never be called when BLIS_ENABLE_JRIR_TLB + // is defined, since the function is only called from macrokernels that were + // designed for slab/rr partitioning. + + const dim_t jr_nt = bli_thrinfo_n_way( thread ); + const dim_t n_iter = n / bf + ( n % bf ? 1 : 0 ); + + // If there is no parallelism in this loop, set the output variables + // and return early. + if ( jr_nt == 1 ) { *start = 0; *end = n_iter; *inc = 1; return; } + + // Local variables for the computed start, end, and increment. + dim_t st, en, in; + + if ( bli_intersects_diag_n( diagoff, m, n ) ) + { + // If the current submatrix intersects the diagonal, try to be + // intelligent about how threads are assigned work by using the + // quadratic partitioning function. + + bli_thread_range_weighted_sub + ( + thread, diagoff, uplo, uplo, m, n, bf, + handle_edge_low, &st, &en + ); + in = bf; + } + else + { + // If the current submatrix does not intersect the diagonal, then we + // are free to perform a uniform (and contiguous) slab partitioning. + + bli_thread_range_sub + ( + thread, n, bf, + handle_edge_low, &st, &en + ); + in = bf; + } + + // Convert the start and end column indices into micropanel indices by + // dividing by the blocking factor (which, for the jr loop, is NR). If + // either one yields a remainder, add an extra unit to the result. This + // is necessary for situations where there are t threads with t-1 or + // fewer micropanels of work, including an edge case. For example, if + // t = 3 and n = 10 (with bf = NR = 8), then we want start and end for + // each thread to be: + // + // column index upanel index + // tid 0: start, end = 0, 8 -> start, end = 0, 1 + // tid 1: start, end = 8, 10 -> start, end = 1, 2 + // tid 2: start, end = 10, 10 -> start, end = 2, 2 + // + // In this example, it's important that thread (tid) 2 gets no work, and + // we express that by specifying start = end = n, which is a non-existent + // column index. + + if ( st % bf == 0 ) *start = st / bf; + else *start = st / bf + 1; + + if ( en % bf == 0 ) *end = en / bf; + else *end = en / bf + 1; + + *inc = in / bf; + +#endif +} diff --git a/frame/thread/bli_thread_range_slab_rr.h b/frame/thread/bli_thread_range_slab_rr.h new file mode 100644 index 0000000000..3e9797363b --- /dev/null +++ b/frame/thread/bli_thread_range_slab_rr.h @@ -0,0 +1,116 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_THREAD_RANGE_SLAB_RR_H +#define BLIS_THREAD_RANGE_SLAB_RR_H + +BLIS_INLINE void bli_thread_range_rr + ( + const thrinfo_t* thread, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end, + dim_t* inc + ) +{ + const dim_t tid = bli_thrinfo_work_id( thread ); + const dim_t nt = bli_thrinfo_n_way( thread ); + const dim_t n_iter = n / bf + ( n % bf ? 1 : 0 ); + + // Use round-robin (interleaved) partitioning of jr/ir loops. + *start = tid; + *end = n_iter; + *inc = nt; +} + +BLIS_INLINE void bli_thread_range_sl + ( + const thrinfo_t* thread, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end, + dim_t* inc + ) +{ + // Use contiguous slab partitioning of jr/ir loops. + bli_thread_range_sub( thread, n, bf, handle_edge_low, start, end ); + *inc = 1; +} + +BLIS_INLINE void bli_thread_range_slrr + ( + const thrinfo_t* thread, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end, + dim_t* inc + ) +{ + // Define a general-purpose slab/rr function whose definition depends on + // whether slab or round-robin partitioning was requested at configure-time. + // Note that this function also uses the slab code path when tlb is enabled. + // If this is ever changed, make sure to change bli_is_my_iter() since they + // are used together by packm. + +#ifdef BLIS_ENABLE_JRIR_RR + bli_thread_range_rr( thread, n, bf, handle_edge_low, start, end, inc ); +#else // ifdef ( _SLAB || _TLB ) + bli_thread_range_sl( thread, n, bf, handle_edge_low, start, end, inc ); +#endif +} + +// ----------------------------------------------------------------------------- + +void bli_thread_range_quad + ( + const thrinfo_t* thread, + doff_t diagoff, + uplo_t uplo, + dim_t m, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end, + dim_t* inc + ); + +#endif + diff --git a/frame/thread/bli_thread_range_tlb.c b/frame/thread/bli_thread_range_tlb.c new file mode 100644 index 0000000000..546ed341d6 --- /dev/null +++ b/frame/thread/bli_thread_range_tlb.c @@ -0,0 +1,1699 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// ----------------------------------------------------------------------------- + +#define PRINT_MODE +#define PGUARD if ( 0 ) +//#define PRINT_RESULT + + +#if 0 +dim_t bli_thread_range_tlb + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const uplo_t uplo, + const dim_t m_iter, + const dim_t n_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + dim_t n_ut_for_me; + + if ( bli_is_lower( uplo ) ) + { + n_ut_for_me = bli_thread_range_tlb_l + ( + nt, tid, diagoff, m_iter, n_iter, mr, nr, j_st_p, i_st_p + ); + } + else if ( bli_is_upper( uplo ) ) + { + n_ut_for_me = bli_thread_range_tlb_u + ( + nt, tid, diagoff, m_iter, n_iter, mr, nr, j_st_p, i_st_p + ); + } + else // if ( bli_is_dense( uplo ) ) + { + n_ut_for_me = bli_thread_range_tlb_d + ( + nt, tid, m_iter, n_iter, mr, nr, j_st_p, i_st_p + ); + } + + return n_ut_for_me; +} +#endif + +// ----------------------------------------------------------------------------- + +dim_t bli_thread_range_tlb_l + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + // This function implements tile-level load balancing for a + // lower-trapezoidal submatrix. This partitioning guarantees that all + // threads are assigned nearly the same number of microtiles-worth of work, + // with a maximum imbalance of one microtile. It makes no effort, however, + // to account for differences in threads' workload that is attributable to + // differences in the number of edge-case (or diagonal-intersecting) + // microtiles (which incur slightly more work since they must first write + // to a temporary microtile before updating the output C matrix). + + // Assumption: -mr < diagoff. Make sure to prune leading rows beforehand! + if ( diagoff <= -mr ) bli_abort(); + + // + // -- Step 1: Compute the computational area of the region ----------------- + // + + // Compute the m and n dimensions according to m_iter and n_iter. (These + // m and n dims will likely be larger than the actual m and n since they + // "round up" the edge case microtiles into full-sized microtiles.) + const dim_t m = m_iter * mr; + const dim_t n = n_iter * nr; + + // For the purposes of many computations in this function, we aren't + // interested in the extent to which diagoff exceeds n (if it does) + // So we use a new variable that is guaranteed to be no greater than n. + const doff_t diagoffmin = bli_min( diagoff, n ); + + const dim_t m_rect = m; + const dim_t n_rect = ( diagoffmin / nr ) * nr; + + const dim_t rect_area = m_rect * n_rect; + const dim_t nonrect_area = m * n - rect_area; + + //const dim_t offn_rect = 0; + const dim_t offn_nonrect = n_rect; + const dim_t diagoff_nonrect = diagoffmin - n_rect; //diagoff % nr; + + const dim_t n_nonrect = n - n_rect; + + const dim_t offn_ut_nonrect = ( diagoffmin / nr ); + + PGUARD printf( "---------------------------\n" ); + PGUARD printf( "min(diagoff,n): %7ld\n", diagoffmin ); + PGUARD printf( "offn_ut_nonrect: %7ld\n", offn_ut_nonrect ); + PGUARD printf( "offn_nonrect: %7ld\n", offn_nonrect ); + PGUARD printf( "diagoff_nonrect: %7ld\n", diagoff_nonrect ); + PGUARD printf( "n_nonrect: %7ld\n", n_nonrect ); + PGUARD printf( "---------------------------\n" ); + + dim_t num_unref_ut = 0; + + // Count the number of unreferenced utiles strictly above the diagonal. + for ( dim_t j = 0; j < n_nonrect; j += nr ) + { + const dim_t diagoff_j = diagoff_nonrect - j; + + // diagoff_j will always be at most nr - 1, but will typically be + // negative. This is because the non-rectangular region's diagonal + // offset will be at most nr - 1 for the first column of microtiles, + // since if it were more than nr - 1, that column would have already + // been pruned away (via the implicit pruning of diagoff_nonrect). + // NOTE: We use bli_max() to ensure that -diagoff_j / mr does not + // become negative, which can only happen if "top" pruning is not + // performed beforehand (and so it really isn't necessary here). + const dim_t num_unref_ut_j = bli_max( ( -diagoff_j / mr ), 0 ); + + num_unref_ut += num_unref_ut_j; + + PGUARD printf( "j %7ld\n", j ); + PGUARD printf( "diagoff_j %7ld\n", diagoff_j ); + PGUARD printf( "num_unref_ut_j %7ld\n", num_unref_ut_j ); + PGUARD printf( "num_unref_ut %7ld\n", num_unref_ut ); + PGUARD printf( "\n" ); + } + PGUARD printf( "---------------------------\n" ); + + const dim_t tri_unref_area = num_unref_ut * mr * nr; + const dim_t tri_ref_area = nonrect_area - tri_unref_area; + const dim_t total_ref_area = rect_area + tri_ref_area; + + PGUARD printf( "gross area: %7ld\n", m * n ); + PGUARD printf( "rect_area: %7ld\n", rect_area ); + PGUARD printf( "nonrect_area: %7ld\n", nonrect_area ); + PGUARD printf( "tri_unref_area: %7ld\n", tri_unref_area ); + PGUARD printf( "tri_ref_area: %7ld\n", tri_ref_area ); + PGUARD printf( "total_ref_area: %7ld\n", total_ref_area ); + PGUARD printf( "---------------------------\n" ); + + // + // -- Step 2: Compute key utile counts (per thread, per column, etc.) ------ + // + + const dim_t n_ut_ref = total_ref_area / ( mr * nr ); + //const dim_t n_ut_tri_ref = tri_ref_area / ( mr * nr ); + const dim_t n_ut_rect = rect_area / ( mr * nr ); + + PGUARD printf( "n_ut_ref: %7ld\n", n_ut_ref ); + //PGUARD printf( "n_ut_tri_ref: %7ld\n", n_ut_tri_ref ); + PGUARD printf( "n_ut_rect: %7ld\n", n_ut_rect ); + PGUARD printf( "---------------------------\n" ); + + // Compute the number of microtiles to allocate per thread as well as the + // number of leftover microtiles. + const dim_t n_ut_per_thr = n_ut_ref / nt; + const dim_t n_ut_pt_left = n_ut_ref % nt; + + PGUARD printf( "n_ut_per_thr: %7ld\n", n_ut_per_thr ); + PGUARD printf( "n_ut_pt_left: %7ld\n", n_ut_pt_left ); + PGUARD printf( "---------------------------\n" ); + + const dim_t n_ut_per_col = m_iter; + + PGUARD printf( "n_ut_per_col: %7ld\n", n_ut_per_col ); + + // Allocate one of the leftover microtiles to the current thread if its + // tid is one of the lower thread ids. + const dim_t n_ut_for_me = n_ut_per_thr + ( tid < n_ut_pt_left ? 1 : 0 ); + + PGUARD printf( "n_ut_for_me: %7ld (%ld+%ld)\n", n_ut_for_me, + n_ut_per_thr, n_ut_for_me - n_ut_per_thr ); + + // Compute the number of utiles prior to the current thread's starting + // point. This is the sum of all n_ut_for_me for all thread ids less + // than tid. Notice that the second half of this expression effectively + // adds one extra microtile for each lower-valued thread id, up to + // n_ut_pt_left. + const dim_t n_ut_before = tid * n_ut_per_thr + bli_min( tid, n_ut_pt_left ); + + PGUARD printf( "n_ut_before: %7ld\n", n_ut_before ); + PGUARD printf( "---------------------------\n" ); + + // + // -- Step 3: Compute the starting j/i utile offset for a given tid -------- + // + + dim_t j_st; + dim_t i_st; + + if ( n_ut_before < n_ut_rect ) + { + // This branch handles scenarios where the number of microtiles + // assigned to lower thread ids is strictly less than the number of + // utiles in the rectangular region. This means that calculating the + // starting microtile index is easy (because it does not need to + // take the location of the diagonal into account). + + PGUARD printf( "Rectangular region: n_ut_before < n_ut_rect\n" ); + PGUARD printf( "\n" ); + + const dim_t ut_index_rect_st = n_ut_before; + + PGUARD printf( "ut_index_st: %7ld\n", ut_index_rect_st ); + PGUARD printf( "---------------------------\n" ); + + j_st = ut_index_rect_st / n_ut_per_col; + i_st = ut_index_rect_st % n_ut_per_col; + + PGUARD printf( "j_st, i_st (fnl=) %4ld,%4ld\n", j_st, i_st ); + } + else // if ( n_ut_rect <= n_ut_before ) + { + // This branch handles scenarios where the number of microtiles + // assigned to lower thread ids exceeds (or equals) the number of + // utiles in the rectangular region. This means we need to observe the + // location of the diagonal to see how many utiles are referenced per + // column of utiles. + + PGUARD printf( "Diagonal region: n_ut_rect <= n_ut_before\n" ); + PGUARD printf( "\n" ); + + // This will be the number of microtile columns we will immediately + // advance past to get to the diagonal region. + const dim_t n_ut_col_adv = offn_ut_nonrect; + + PGUARD printf( "n_ut_col_adv: %7ld\n", n_ut_col_adv ); + + // In order to find j_st and i_st, we need to "allocate" n_ut_before + // microtiles. + dim_t n_ut_tba = n_ut_before; + + PGUARD printf( "n_ut_tba: %7ld\n", n_ut_tba ); + + // Advance past the rectangular region, decrementing n_ut_tba + // accordingly. + n_ut_tba -= n_ut_per_col * n_ut_col_adv; + + PGUARD printf( "n_ut_tba_1: %7ld\n", n_ut_tba ); + PGUARD printf( "\n" ); + + // In case n_ut_tba == 0. Only happens when n_ut_before == n_ut_rect. + j_st = n_ut_col_adv; + i_st = 0; + + for ( dim_t j = n_ut_col_adv; 0 < n_ut_tba; ++j ) + { + const dim_t diagoff_j = diagoffmin - j*nr; + const dim_t n_ut_skip_j = bli_max( -diagoff_j / mr, 0 ); + const dim_t n_ut_this_col = n_ut_per_col - n_ut_skip_j; + + PGUARD printf( "j: %7ld\n", j ); + PGUARD printf( "diagoff_j: %7ld\n", diagoff_j ); + PGUARD printf( "n_ut_skip_j: %7ld\n", n_ut_skip_j ); + PGUARD printf( "n_ut_this_col: %7ld\n", n_ut_this_col ); + PGUARD printf( "n_ut_tba_j0: %7ld\n", n_ut_tba ); + + if ( n_ut_tba < n_ut_this_col ) + { + // If the number of utiles to allocate is less than the number + // in this column, we know that j_st will refer to the current + // column. To find i_st, we first skip to the utile that + // intersects the diagonal and then add n_ut_tba. + j_st = j; + i_st = n_ut_skip_j + n_ut_tba; + PGUARD printf( "j_st, i_st (fnl<) %4ld,%4ld\n", j_st, i_st ); + } + else if ( n_ut_tba == n_ut_this_col ) + { + // If the number of utiles to allocate is exactly equal to the + // number in this column, we know that j_st will refer to the + // *next* column. But to find i_st, we will have to take the + // location of the diagonal into account. + const doff_t diagoff_jp1 = diagoff_j - nr; + const dim_t n_ut_skip_jp1 = bli_max( -diagoff_jp1 / mr, 0 ); + + j_st = j + 1; + i_st = n_ut_skip_jp1; + PGUARD printf( "j_st, i_st (fnl=) %4ld,%4ld\n", j_st, i_st ); + } + + // No matter what (especially if the number of utiles to allocate + // exceeds the number in this column), we decrement n_ut_tba attempt + // to continue to the next iteration. (Note: If either of the two + // branches above is triggered, n_ut_tba will be decremented down to + // zero (or less), in which case this will be the final iteration.) + n_ut_tba -= n_ut_this_col; + + PGUARD printf( "n_ut_tba_j1: %7ld\n", n_ut_tba ); + PGUARD printf( "\n" ); + } + } + + // + // -- Step 4: Save the results --------------------------------------------- + // + + *j_st_p = j_st; + *i_st_p = i_st; + + #ifdef PRINT_RESULT + printf( "j_st, i_st (mem) %4ld,%4ld (n_ut: %4ld)\n", + j_st, i_st, n_ut_for_me ); + #endif + + // Return the number of utiles that this thread was allocated. + return n_ut_for_me; +} + +// ----------------------------------------------------------------------------- + +dim_t bli_thread_range_tlb_u + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + // This function implements tile-level load balancing for an + // upper-trapezoidal submatrix. This partitioning guarantees that all + // threads are assigned nearly the same number of microtiles-worth of work, + // with a maximum imbalance of one microtile. It makes no effort, however, + // to account for differences in threads' workload that is attributable to + // differences in the number of edge-case (or diagonal-intersecting) + // microtiles (which incur slightly more work since they must first write + // to a temporary microtile before updating the output C matrix). + + // Assumption: diagoff < nr. Make sure to prune leading columns beforehand! + if ( nr <= diagoff ) bli_abort(); + + // + // -- Step 1: Compute the computational area of the region ----------------- + // + + // Compute the m and n dimensions according to m_iter and n_iter. (These + // m and n dims will likely be larger than the actual m and n since they + // "round up" the edge case microtiles into full-sized microtiles.) + const dim_t m = m_iter * mr; + const dim_t n = n_iter * nr; + + // For the purposes of many computations in this function, we aren't + // interested in the extent to which diagoff exceeds -m (if it does) + // So we use a new variable that is guaranteed to be no less than -m. + const doff_t diagoffmin = bli_max( diagoff, -m ); + + const dim_t m_rect = m; + const dim_t n_rect = ( -diagoffmin / nr ) * nr; + + const dim_t rect_area = m_rect * n_rect; + const dim_t nonrect_area = m * n - rect_area; + + const dim_t offn_rect = n - n_rect; + //const dim_t offn_nonrect = 0; + const dim_t diagoff_nonrect = diagoffmin; + + const dim_t n_nonrect = n - n_rect; + + const dim_t offn_ut_rect = n_iter + ( diagoffmin / nr ); + + PGUARD printf( "---------------------------\n" ); + PGUARD printf( "max(diagoff,-m): %7ld\n", diagoffmin ); + PGUARD printf( "offn_ut_rect: %7ld\n", offn_ut_rect ); + PGUARD printf( "offn_rect: %7ld\n", offn_rect ); + PGUARD printf( "diagoff_nonrect: %7ld\n", diagoff_nonrect ); + PGUARD printf( "n_nonrect: %7ld\n", n_nonrect ); + PGUARD printf( "---------------------------\n" ); + + dim_t num_unref_ut = 0; + + // Count the number of unreferenced utiles strictly below the diagonal. + for ( dim_t j = 0; j < n_nonrect; j += nr ) + { + const dim_t diagoff_j = diagoff_nonrect - j; + + // diagoff_j will always be at most nr - 1, but will typically be + // negative. This is because the non-rectangular region's diagonal + // offset will be at most nr - 1 for the first column of microtiles, + // since if it were more than nr - 1, that column would have already + // been pruned away (prior to this function being called). + // NOTE: We use bli_max() to ensure that ( m + diagoff_j - nr ) / mr + // does not become negative, which can happen in some situations + // during the first iteration if diagoff is relatively close to -m. + // NOTE: We subtract nr from diagoff_j since it's really the diagonal + // offset of the *next* column of utiles that needs to be used to + // determine how many utiles are referenced in the current column. + const dim_t num_unref_ut_j = bli_max( ( m + diagoff_j - nr ) / mr, 0 ); + + num_unref_ut += num_unref_ut_j; + + PGUARD printf( "j %7ld\n", j ); + PGUARD printf( "diagoff_j - nr %7ld\n", diagoff_j - nr ); + PGUARD printf( "num_unref_ut_j %7ld\n", num_unref_ut_j ); + PGUARD printf( "num_unref_ut %7ld\n", num_unref_ut ); + PGUARD printf( "\n" ); + } + PGUARD printf( "---------------------------\n" ); + + const dim_t tri_unref_area = num_unref_ut * mr * nr; + const dim_t tri_ref_area = nonrect_area - tri_unref_area; + const dim_t total_ref_area = rect_area + tri_ref_area; + + PGUARD printf( "gross area: %7ld\n", m * n ); + PGUARD printf( "rect_area: %7ld\n", rect_area ); + PGUARD printf( "nonrect_area: %7ld\n", nonrect_area ); + PGUARD printf( "tri_unref_area: %7ld\n", tri_unref_area ); + PGUARD printf( "tri_ref_area: %7ld\n", tri_ref_area ); + PGUARD printf( "total_ref_area: %7ld\n", total_ref_area ); + PGUARD printf( "---------------------------\n" ); + + // + // -- Step 2: Compute key utile counts (per thread, per column, etc.) ------ + // + + const dim_t n_ut_ref = total_ref_area / ( mr * nr ); + const dim_t n_ut_tri_ref = tri_ref_area / ( mr * nr ); + //const dim_t n_ut_rect = rect_area / ( mr * nr ); + + PGUARD printf( "n_ut_ref: %7ld\n", n_ut_ref ); + PGUARD printf( "n_ut_tri_ref: %7ld\n", n_ut_tri_ref ); + //PGUARD printf( "n_ut_rect: %7ld\n", n_ut_rect ); + PGUARD printf( "---------------------------\n" ); + + // Compute the number of microtiles to allocate per thread as well as the + // number of leftover microtiles. + const dim_t n_ut_per_thr = n_ut_ref / nt; + const dim_t n_ut_pt_left = n_ut_ref % nt; + + PGUARD printf( "n_ut_per_thr: %7ld\n", n_ut_per_thr ); + PGUARD printf( "n_ut_pt_left: %7ld\n", n_ut_pt_left ); + PGUARD printf( "---------------------------\n" ); + + const dim_t n_ut_per_col = m_iter; + + PGUARD printf( "n_ut_per_col: %7ld\n", n_ut_per_col ); + + // Allocate one of the leftover microtiles to the current thread if its + // tid is one of the lower thread ids. + const dim_t n_ut_for_me = n_ut_per_thr + ( tid < n_ut_pt_left ? 1 : 0 ); + + PGUARD printf( "n_ut_for_me: %7ld (%ld+%ld)\n", n_ut_for_me, + n_ut_per_thr, n_ut_for_me - n_ut_per_thr ); + + // Compute the number of utiles prior to the current thread's starting + // point. This is the sum of all n_ut_for_me for all thread ids less + // than tid. Notice that the second half of this expression effectively + // adds one extra microtile for each lower-valued thread id, up to + // n_ut_pt_left. + const dim_t n_ut_before = tid * n_ut_per_thr + bli_min( tid, n_ut_pt_left ); + + PGUARD printf( "n_ut_before: %7ld\n", n_ut_before ); + PGUARD printf( "---------------------------\n" ); + + // + // -- Step 3: Compute the starting j/i utile offset for a given tid -------- + // + + dim_t j_st; + dim_t i_st; + + if ( n_ut_tri_ref <= n_ut_before ) + { + // This branch handles scenarios where the number of microtiles + // assigned to lower thread ids exceeds (or equals) the number of + // utiles in the diagonal region. This means that calculating the + // starting microtile index is easy (because it does not need to + // take the location of the diagonal into account). + + PGUARD printf( "Rectangular region: n_ut_tri_ref <= n_ut_before\n" ); + PGUARD printf( "\n" ); + + const dim_t ut_index_rect_st = n_ut_before - n_ut_tri_ref; + + PGUARD printf( "ut_index_rect_st: %7ld\n", ut_index_rect_st ); + PGUARD printf( "---------------------------\n" ); + + j_st = offn_ut_rect + ut_index_rect_st / n_ut_per_col; + i_st = ut_index_rect_st % n_ut_per_col; + + PGUARD printf( "j_st, i_st (fnl=) %4ld,%4ld\n", j_st, i_st ); + } + else // if ( n_ut_before < n_ut_tri_ref ) + { + // This branch handles scenarios where the number of microtiles + // assigned to lower thread ids is strictly less than the number of + // utiles in the diagonal region. This means we need to observe the + // location of the diagonal to see how many utiles are referenced per + // column of utiles. + + PGUARD printf( "Diagonal region: n_ut_before < n_ut_tri_ref\n" ); + PGUARD printf( "\n" ); + + // This will be the number of microtile columns we will immediately + // advance past to get to the diagonal region. + const dim_t n_ut_col_adv = 0; + + PGUARD printf( "n_ut_col_adv: %7ld\n", n_ut_col_adv ); + + // In order to find j_st and i_st, we need to "allocate" n_ut_before + // microtiles. + dim_t n_ut_tba = n_ut_before; + + PGUARD printf( "n_ut_tba: %7ld\n", n_ut_tba ); + + // No need to advance since the upper-trapezoid begins with the + // diagonal region. + //n_ut_tba -= 0; + + PGUARD printf( "n_ut_tba_1: %7ld\n", n_ut_tba ); + PGUARD printf( "\n" ); + + // In case n_ut_tba == 0. Only happens when n_ut_before == 0. + j_st = 0; + i_st = 0; + + for ( dim_t j = n_ut_col_adv; 0 < n_ut_tba; ++j ) + { + const dim_t diagoff_j = diagoffmin - j*nr; + const dim_t n_ut_skip_j = bli_max( ( m + diagoff_j - nr ) / mr, 0 ); + const dim_t n_ut_this_col = n_ut_per_col - n_ut_skip_j; + + PGUARD printf( "j: %7ld\n", j ); + PGUARD printf( "diagoff_j: %7ld\n", diagoff_j ); + PGUARD printf( "n_ut_skip_j: %7ld\n", n_ut_skip_j ); + PGUARD printf( "n_ut_this_col: %7ld\n", n_ut_this_col ); + PGUARD printf( "n_ut_tba_j0: %7ld\n", n_ut_tba ); + + if ( n_ut_tba < n_ut_this_col ) + { + // If the number of utiles to allocate is less than the number + // in this column, we know that j_st will refer to the current + // column. To find i_st, we simply use n_ut_tba. + j_st = j; + i_st = n_ut_tba; + PGUARD printf( "j_st, i_st (fnl<) %4ld,%4ld\n", j_st, i_st ); + } + else if ( n_ut_tba == n_ut_this_col ) + { + // If the number of utiles to allocate is exactly equal to the + // number in this column, we know that j_st will refer to the + // *next* column. In this situation, i_st will always be 0. + + j_st = j + 1; + i_st = 0; + PGUARD printf( "j_st, i_st (fnl=) %4ld,%4ld\n", j_st, i_st ); + } + + // No matter what (especially if the number of utiles to allocate + // exceeds the number in this column), we decrement n_ut_tba attempt + // to continue to the next iteration. (Note: If either of the two + // branches above is triggered, n_ut_tba will be decremented down to + // zero (or less), in which case this will be the final iteration.) + n_ut_tba -= n_ut_this_col; + + PGUARD printf( "n_ut_tba_j1: %7ld\n", n_ut_tba ); + PGUARD printf( "\n" ); + } + } + + // + // -- Step 4: Save the results --------------------------------------------- + // + + *j_st_p = j_st; + *i_st_p = i_st; + + #ifdef PRINT_RESULT + printf( "j_st, i_st (mem) %4ld,%4ld (n_ut: %4ld)\n", + j_st, i_st, n_ut_for_me ); + #endif + + // Return the number of utiles that this thread was allocated. + return n_ut_for_me; +} + +// ----------------------------------------------------------------------------- + +dim_t bli_thread_range_tlb_d + ( + const dim_t nt, + const dim_t tid, + const dim_t m_iter, + const dim_t n_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + // This function implements tile-level load balancing for a + // general/dense submatrix. This partitioning guarantees that all + // threads are assigned nearly the same number of microtiles-worth of work, + // with a maximum imbalance of one microtile. It makes no effort, however, + // to account for differences in threads' workload that is attributable to + // differences in the number of edge-case microtiles (which incur slightly + // more work since they must first write to a temporary microtile before + // updating the output C matrix). + + // + // -- Step 1: Compute the computational area of the region ----------------- + // + + // Compute the m and n dimensions according to m_iter and n_iter. (These + // m and n dims will likely be larger than the actual m and n since they + // "round up" the edge case microtiles into full-sized microtiles.) + const dim_t m = m_iter * mr; + const dim_t n = n_iter * nr; + + const dim_t m_rect = m; + const dim_t n_rect = n; + + const dim_t total_ref_area = m_rect * n_rect; + + PGUARD printf( "total_ref_area: %7ld\n", total_ref_area ); + PGUARD printf( "---------------------------\n" ); + + // + // -- Step 2: Compute key utile counts (per thread, per column, etc.) ------ + // + + const dim_t n_ut_ref = total_ref_area / ( mr * nr ); + + PGUARD printf( "n_ut_ref: %7ld\n", n_ut_ref ); + PGUARD printf( "---------------------------\n" ); + + // Compute the number of microtiles to allocate per thread as well as the + // number of leftover microtiles. + const dim_t n_ut_per_thr = n_ut_ref / nt; + const dim_t n_ut_pt_left = n_ut_ref % nt; + + PGUARD printf( "n_ut_per_thr: %7ld\n", n_ut_per_thr ); + PGUARD printf( "n_ut_pt_left: %7ld\n", n_ut_pt_left ); + PGUARD printf( "---------------------------\n" ); + + const dim_t n_ut_per_col = m_iter; + + PGUARD printf( "n_ut_per_col: %7ld\n", n_ut_per_col ); + + // Allocate one of the leftover microtiles to the current thread if its + // tid is one of the lower thread ids. + const dim_t n_ut_for_me = n_ut_per_thr + ( tid < n_ut_pt_left ? 1 : 0 ); + + PGUARD printf( "n_ut_for_me: %7ld (%ld+%ld)\n", n_ut_for_me, + n_ut_per_thr, n_ut_for_me - n_ut_per_thr ); + + // Compute the number of utiles prior to the current thread's starting + // point. This is the sum of all n_ut_for_me for all thread ids less + // than tid. Notice that the second half of this expression effectively + // adds one extra microtile for each lower-valued thread id, up to + // n_ut_pt_left. + const dim_t n_ut_before = tid * n_ut_per_thr + bli_min( tid, n_ut_pt_left ); + + PGUARD printf( "n_ut_before: %7ld\n", n_ut_before ); + PGUARD printf( "---------------------------\n" ); + + // + // -- Step 3: Compute the starting j/i utile offset for a given tid -------- + // + + const dim_t ut_index_st = n_ut_before; + + PGUARD printf( "ut_index_st: %7ld\n", ut_index_st ); + PGUARD printf( "---------------------------\n" ); + + const dim_t j_st = ut_index_st / n_ut_per_col; + const dim_t i_st = ut_index_st % n_ut_per_col; + + // + // -- Step 4: Save the results --------------------------------------------- + // + + *j_st_p = j_st; + *i_st_p = i_st; + + #ifdef PRINT_RESULT + printf( "j_st, i_st (mem) %4ld,%4ld (n_ut: %4ld)\n", + j_st, i_st, n_ut_for_me ); + #endif + + // Return the number of utiles that this thread was allocated. + return n_ut_for_me; +} + +// ----------------------------------------------------------------------------- + +BLIS_INLINE dim_t bli_tlb_trmm_lx_k_iter + ( + const doff_t diagoff_iter, + const uplo_t uplo, + const dim_t k_iter, + const dim_t ir_iter + ) +{ + if ( bli_is_lower( uplo ) ) + return bli_min( diagoff_iter + ( ir_iter + 1 ), k_iter ); + else // if ( bli_is_upper( uplo ) ) + return k_iter - bli_max( diagoff_iter + ir_iter, 0 ); +} + +BLIS_INLINE dim_t bli_tlb_trmm_rl_k_iter + ( + const doff_t diagoff_iter, + const dim_t k_iter, + const dim_t jr_iter + ) +{ + return k_iter - bli_max( -diagoff_iter + jr_iter, 0 ); +} + +// ----------------------------------------------------------------------------- + +dim_t bli_thread_range_tlb_trmm_ll + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + return bli_thread_range_tlb_trmm_lx_impl + ( + nt, tid, diagoff, BLIS_LOWER, m_iter, n_iter, k_iter, mr, nr, + j_st_p, i_st_p + ); +} + +dim_t bli_thread_range_tlb_trmm_lu + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + return bli_thread_range_tlb_trmm_lx_impl + ( + nt, tid, diagoff, BLIS_UPPER, m_iter, n_iter, k_iter, mr, nr, + j_st_p, i_st_p + ); +} + +dim_t bli_thread_range_tlb_trmm_lx_impl + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const uplo_t uplo, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + // Assumption: 0 <= diagoff (lower); diagoff <= 0 (upper). + // Make sure to prune leading rows (lower) or columns (upper) beforehand! + if ( bli_is_lower( uplo ) && diagoff < 0 ) bli_abort(); + else if ( bli_is_upper( uplo ) && diagoff > 0 ) bli_abort(); + + // Single-threaded cases are simple and allow early returns. + if ( nt == 1 ) + { + const dim_t n_ut_for_me = m_iter * n_iter; + + *j_st_p = 0; + *i_st_p = 0; + + return n_ut_for_me; + } + + // + // -- Step 1: Compute the computational flop cost of each utile column ----- + // + + // Normalize the diagonal offset by mr so that it represents the offset in + // units of mr x mr chunks. + const doff_t diagoff_iter = diagoff / mr; + + // Determine the actual k dimension, in units of mr x mr iterations, capped + // by the k_iter given by the caller. + + PGUARD printf( "---------------------------\n" ); + PGUARD printf( "m_iter: %7ld\n", m_iter ); + PGUARD printf( "n_iter: %7ld\n", n_iter ); + PGUARD printf( "k_iter: %7ld\n", k_iter ); + PGUARD printf( "mr: %7ld\n", mr ); + PGUARD printf( "nr: %7ld\n", nr ); + PGUARD printf( "diagoff_iter: %7ld\n", diagoff_iter ); + + dim_t uops_per_col = 0; + + // Compute the computation flop cost of each microtile column, normalized + // by the number of flops performed by each mr x nr rank-1 update. This + // is simply the sum of all of the k dimensions of each micropanel, up to + // and including (lower) or starting from (upper) the part that intersects + // the diagonal, or the right (lower) or left (upper) edge of the matrix, + // as applicable. + for ( dim_t i = 0; i < m_iter; ++i ) + { + // Don't allow k_a1011 to exceed k_iter, which is the maximum possible + // k dimension (in units of mr x mr chunks of micropanel). + const dim_t k_i_iter + = bli_tlb_trmm_lx_k_iter( diagoff_iter, uplo, k_iter, i ); + + uops_per_col += k_i_iter; + } + + PGUARD printf( "uops_per_col: %7ld\n", uops_per_col ); + + // + // -- Step 2: Compute key flop counts (per thread, per column, etc.) ------- + // + + // Compute the total cost for the entire block-panel multiply. + const dim_t total_uops = uops_per_col * n_iter; + + // Compute the number of microtile ops to allocate per thread as well as the + // number of leftover microtile ops. + const dim_t n_uops_per_thr = total_uops / nt; + const dim_t n_uops_pt_left = total_uops % nt; + + PGUARD printf( "---------------------------\n" ); + PGUARD printf( "total_uops: %7ld\n", total_uops ); + PGUARD printf( "n_uops_per_thr: %7ld\n", n_uops_per_thr ); + PGUARD printf( "n_uops_pt_left: %7ld\n", n_uops_pt_left ); + + // + // -- Step 3: Compute the starting j/i utile offset for a given tid -------- + // + + PGUARD printf( "---------------------------\n" ); + PGUARD printf( "total_utiles: %7ld\n", m_iter * n_iter ); + PGUARD printf( "---------------------------\n" ); + + dim_t j_st_cur = 0; dim_t j_en_cur = 0; + dim_t i_st_cur = 0; dim_t i_en_cur = 0; + + PGUARD printf( " tid %ld will start at j,i: %ld %ld\n", + ( dim_t )0, j_st_cur, i_st_cur ); + + // Find the utile update that pushes uops_tba to 0 or less. +#ifdef PRINT_MODE + for ( dim_t tid_i = 0; tid_i < nt; ++tid_i ) +#else + for ( dim_t tid_i = 0; tid_i < nt - 1; ++tid_i ) +#endif + { + const dim_t uops_ta = n_uops_per_thr + ( tid_i < n_uops_pt_left ? 1 : 0 ); + dim_t uops_tba = uops_ta; + dim_t j = j_st_cur; + dim_t n_ut_for_me = 0; + bool done_e = FALSE; + + PGUARD printf( "tid_i: %ld n_uops to alloc: %3ld \n", tid_i, uops_tba ); + + // This code begins allocating uops when the starting point is somewhere + // after the first microtile. Typically this will not be enough to + // allocate all uops, except for small matrices (and/or high numbers of + // threads), in which case the code signals an early finish (via done_e). + if ( 0 < i_st_cur ) + { + dim_t i; + + //PGUARD printf( "tid_i: %ld uops left to alloc: %2ld \n", tid_i, uops_tba ); + + for ( i = i_st_cur; i < m_iter; ++i ) + { + n_ut_for_me += 1; + + const dim_t uops_tba_new + = uops_tba - + bli_tlb_trmm_lx_k_iter( diagoff_iter, uplo, k_iter, i ); + + uops_tba = uops_tba_new; + + PGUARD printf( "tid_i: %ld i: %2ld (1 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, i, n_ut_for_me, uops_ta - uops_tba ); + + if ( uops_tba_new <= 0 ) { j_en_cur = j; i_en_cur = i; done_e = TRUE; + break; } + } + + if ( i == m_iter ) j += 1; + } + + // This code advances over as many columns of utiles as possible and then + // walks down to the correct utile within the subsequent column. However, + // it gets skipped entirely if the previous code block was able to + // allocate all of the current tid's uops. + if ( !done_e ) + { + const dim_t j_inc0 = uops_tba / uops_per_col; + const dim_t j_left0 = uops_tba % uops_per_col; + + // We need to set a hard limit on how much j_inc can be. Namely, + // it should not exceed the number of utile columns that are left + // in the matrix. We also correctly compute j_left when the initial + // computation of j_inc0 above exceeds the revised j_inc, but this + // is mostly only so that in these situations the debug statements + // report the correct numbers. + const dim_t j_inc = bli_min( j_inc0, n_iter - j ); + const dim_t delta = j_inc0 - j_inc; + const dim_t j_left = j_left0 + delta * uops_per_col; + + // Increment j by the number of full utile columns we allocate, and + // set the remaining utile ops to be allocated to the remainder. + j += j_inc; + uops_tba = j_left; + + n_ut_for_me += j_inc * m_iter; + + PGUARD printf( "tid_i: %ld advanced to col: %2ld (uops traversed: %ld)\n", + tid_i, j, uops_per_col * j_inc ); + PGUARD printf( "tid_i: %ld j: %2ld ( n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + PGUARD printf( "tid_i: %ld uops left to alloc: %2ld \n", tid_i, j_left ); + + if ( uops_tba == 0 ) + { + // If advancing j_inc columns allocated all of our uops, then + // designate the last iteration of the previous column as the + // end point. + j_en_cur = j - 1; + i_en_cur = m_iter - 1; + } + else if ( j > n_iter ) bli_abort(); // safety check. + else if ( j == n_iter ) + { + // If we still have at least some uops to allocate, and advancing + // j_inc columns landed us at the beginning of the first non- + // existent column (column n_iter), then we're done. (The fact + // that we didn't get to allocate all of our uops just means that + // the lower tids slightly overshot their allocations, leaving + // fewer uops for the last thread.) + } + else // if ( 0 < uops_tba && j < n_iter ) + { + // If we have at least some uops to allocate, and we still have + // at least some columns to process, then we search for the + // utile that will put us over the top. + + for ( dim_t i = 0; i < m_iter; ++i ) + { + n_ut_for_me += 1; + + const dim_t uops_tba_new + = uops_tba - + bli_tlb_trmm_lx_k_iter( diagoff_iter, uplo, k_iter, i ); + + uops_tba = uops_tba_new; + + PGUARD printf( "tid_i: %ld i: %2ld (4 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, i, n_ut_for_me, uops_ta - uops_tba ); + + if ( uops_tba_new <= 0 ) { j_en_cur = j; i_en_cur = i; + break; } + } + } + } + + + PGUARD printf( "tid_i: %ld (5 n_ut_cur: %ld) (overshoot: %ld out of %ld)\n", + tid_i, n_ut_for_me, -uops_tba, uops_ta ); + + if ( tid_i == tid ) + { + *j_st_p = j_st_cur; + *i_st_p = i_st_cur; + return n_ut_for_me; + } + + // Use the current tid's ending i,j values to determine the starting i,j + // values for the next tid. + j_st_cur = j_en_cur; + i_st_cur = i_en_cur + 1; + if ( i_st_cur == m_iter ) { j_st_cur += 1; i_st_cur = 0; } + + PGUARD printf( "tid_i: %ld (6 n_ut_cur: %ld)\n", + tid_i, n_ut_for_me ); + PGUARD printf( "tid_i: %ld tid %ld will start at j,i: %ld %ld\n", + tid_i, tid_i + 1, j_st_cur, i_st_cur ); + PGUARD printf( "---------------------------\n" ); + } + +#ifndef PRINT_MODE + + // + // -- Step 4: Handle the last thread's allocation -------------------------- + // + + // An optimization: The above loop runs to nt - 1 rather than nt since it's + // easy to count the number of utiles allocated to the last thread. + const dim_t n_ut_for_me = m_iter - i_st_cur + + (n_iter - j_st_cur - 1) * m_iter; + *j_st_p = j_st_cur; + *i_st_p = i_st_cur; + + PGUARD printf( "tid_i: %ld (7 n_ut_for_me: %ld) (j,i_st: %ld %ld)\n", + tid, n_ut_for_me, j_st_cur, i_st_cur ); + + return n_ut_for_me; +#else + // This line should never execute, but we need it to satisfy the compiler. + return -1; +#endif +} + +// ----------------------------------------------------------------------------- + +#if 0 +dim_t bli_thread_range_tlb_trmm_r + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const uplo_t uplo, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + dim_t n_ut_for_me; + + if ( bli_is_lower( uplo ) ) + { + inc_t j_en_l, i_en_l; + + n_ut_for_me = bli_thread_range_tlb_trmm_rl_impl + ( + nt, tid, diagoff, m_iter, n_iter, k_iter, mr, nr, + j_st_p, i_st_p, &j_en_l, &i_en_l + ); + } + else // if ( bli_is_upper( uplo ) ) + { + inc_t j_st_l, i_st_l; + inc_t j_en_l, i_en_l; + + // Reverse the effective tid and use the diagonal offset as if the m and + // n dimension were reversed (similar to a 180 degree rotation). This + // transforms the problem into one of allocating ranges for a lower- + // triangular matrix, for which we already have a special routine. + const dim_t tid_rev = nt - tid - 1; + const doff_t diagoff_rev = nr*n_iter - ( nr*k_iter + diagoff ); + + n_ut_for_me = bli_thread_range_tlb_trmm_rl_impl + ( + nt, tid_rev, diagoff_rev, m_iter, n_iter, k_iter, mr, nr, + &j_st_l, &i_st_l, &j_en_l, &i_en_l + ); + + // The ending j and i offsets will serve as our starting offsets + // returned to the caller, but first we have to reverse the offsets so + // that their semantics are once again relative to an upper-triangular + // matrix. + j_en_l = n_iter - j_en_l - 1; + i_en_l = m_iter - i_en_l - 1; + + *j_st_p = j_en_l; + *i_st_p = i_en_l; + } + + return n_ut_for_me; +} +#endif + +dim_t bli_thread_range_tlb_trmm_rl + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + inc_t j_en_l, i_en_l; + + return bli_thread_range_tlb_trmm_rl_impl + ( + nt, tid, diagoff, m_iter, n_iter, k_iter, mr, nr, + j_st_p, i_st_p, &j_en_l, &i_en_l + ); +} + +dim_t bli_thread_range_tlb_trmm_ru + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ) +{ + inc_t j_st_l, i_st_l; + inc_t j_en_l, i_en_l; + + // Reverse the effective tid and use the diagonal offset as if the m and + // n dimension were reversed (similar to a 180 degree rotation). This + // transforms the problem into one of allocating ranges for a lower- + // triangular matrix, for which we already have a special routine. + const dim_t tid_rev = nt - tid - 1; + const doff_t diagoff_rev = nr*n_iter - ( nr*k_iter + diagoff ); + + const dim_t n_ut_for_me = bli_thread_range_tlb_trmm_rl_impl + ( + nt, tid_rev, diagoff_rev, m_iter, n_iter, k_iter, mr, nr, + &j_st_l, &i_st_l, &j_en_l, &i_en_l + ); + + // The ending j and i offsets will serve as our starting offsets + // returned to the caller, but first we have to reverse the offsets so + // that their semantics are once again relative to an upper-triangular + // matrix. + j_en_l = n_iter - j_en_l - 1; + i_en_l = m_iter - i_en_l - 1; + + *j_st_p = j_en_l; + *i_st_p = i_en_l; + + return n_ut_for_me; +} + +dim_t bli_thread_range_tlb_trmm_rl_impl + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p, + inc_t* j_en_p, + inc_t* i_en_p + ) +{ + // Assumption: 0 <= diagoff. Make sure to prune leading rows beforehand! + if ( diagoff < 0 ) bli_abort(); + + // Single-threaded cases are simple and allow early returns. + if ( nt == 1 ) + { + const dim_t n_ut_for_me = m_iter * n_iter; + + *j_st_p = 0; + *i_st_p = 0; + *j_en_p = n_iter - 1; + *i_en_p = m_iter - 1; + + return n_ut_for_me; + } + + // + // -- Step 1: Compute the computational volume of the region --------------- + // + + // Normalize the diagonal offset by nr so that it represents the offset in + // units of nr x nr chunks. + const doff_t diagoff_iter = diagoff / nr; + + // For the purposes of many computations in this function, we aren't + // interested in the extent to which diagoff exceeds n (if it does) + // So we use a new variable that is guaranteed to be no greater than n. + const doff_t diagoffmin_iter = bli_min( diagoff_iter, n_iter ); + + const dim_t k_rect = k_iter; + const dim_t n_rect = diagoffmin_iter; + + const dim_t gross_area = k_rect * n_iter; + const dim_t rect_area = k_rect * n_rect; + const dim_t nonrect_area = gross_area - rect_area; + + const dim_t offn_nonrect = n_rect; + const dim_t diagoff_nonrect = 0; + + const dim_t n_nonrect = n_iter - n_rect; + + const dim_t offn_ut_nonrect = diagoffmin_iter; + + PGUARD printf( "---------------------------\n" ); + PGUARD printf( "m_iter: %7ld\n", m_iter ); + PGUARD printf( "k_iter: %7ld\n", k_iter ); + PGUARD printf( "n_iter: %7ld\n", n_iter ); + PGUARD printf( "min(diagoff_it,n): %7ld\n", diagoffmin_iter ); + PGUARD printf( "offn_ut_nonrect: %7ld\n", offn_ut_nonrect ); + PGUARD printf( "offn_nonrect: %7ld\n", offn_nonrect ); + PGUARD printf( "diagoff_nonrect: %7ld\n", diagoff_nonrect ); + PGUARD printf( "n_nonrect: %7ld\n", n_nonrect ); + PGUARD printf( "---------------------------\n" ); + + const dim_t num_unref_ut0 = n_nonrect * ( n_nonrect - 1 ) / 2; + const dim_t num_unref_ut = bli_max( 0, num_unref_ut0 ); + + const dim_t tri_unref_area = num_unref_ut; + const dim_t tri_ref_area = nonrect_area - tri_unref_area; + const dim_t total_ref_area = rect_area + tri_ref_area; + const dim_t rect_vol = rect_area * m_iter; + const dim_t tri_ref_vol = tri_ref_area * m_iter; + const dim_t total_vol = total_ref_area * m_iter; + + PGUARD printf( "gross_area: %7ld\n", gross_area ); + PGUARD printf( "nonrect_area: %7ld\n", nonrect_area ); + PGUARD printf( "tri_unref_area: %7ld\n", tri_unref_area ); + PGUARD printf( "rect_area: %7ld\n", rect_area ); + PGUARD printf( "tri_ref_area: %7ld\n", tri_ref_area ); + PGUARD printf( "total_ref_area: %7ld\n", total_ref_area ); + PGUARD printf( "---------------------------\n" ); + PGUARD printf( "rect_vol (uops): %7ld\n", rect_vol ); + PGUARD printf( "tri_ref_vol (uops): %7ld\n", tri_ref_vol ); + PGUARD printf( "total_vol (uops): %7ld\n", total_vol ); + PGUARD printf( "---------------------------\n" ); + + // + // -- Step 2: Compute key flop counts (per thread, per column, etc.) ------- + // + + //const dim_t rect_uops = rect_vol; + //const dim_t tri_ref_uops = tri_ref_vol; + const dim_t total_uops = total_vol; + + // Compute the number of microtile ops to allocate per thread as well as the + // number of leftover microtile ops. + const dim_t n_uops_per_thr = total_uops / nt; + const dim_t n_uops_pt_left = total_uops % nt; + + PGUARD printf( "n_threads: %7ld\n", nt ); + PGUARD printf( "n_uops_per_thr: %7ld\n", n_uops_per_thr ); + PGUARD printf( "n_uops_pt_left: %7ld\n", n_uops_pt_left ); + PGUARD printf( "---------------------------\n" ); + + const dim_t uops_per_col_rect = m_iter * k_iter; + + PGUARD printf( "uops_per_col_rect: %7ld\n", uops_per_col_rect ); + + // Allocate one of the leftover uops to the current thread if its tid is + // one of the lower thread ids. + //const dim_t n_uops_for_me = n_uops_per_thr + ( tid < n_uops_pt_left ? 1 : 0 ); + + //PGUARD printf( "n_uops_for_me: %7ld (%ld+%ld)\n", + // n_uops_for_me, n_uops_per_thr, n_uops_for_me - n_uops_per_thr ); + + // + // -- Step 3: Compute the starting j/i utile offset for a given tid) ------- + // + + PGUARD printf( "---------------------------\n" ); + PGUARD printf( "total_utiles: %7ld\n", m_iter * n_iter ); + PGUARD printf( "---------------------------\n" ); + + dim_t j_st_cur = 0; dim_t j_en_cur = 0; + dim_t i_st_cur = 0; dim_t i_en_cur = 0; + + // Find the utile update that pushes uops_tba to 0 or less. +#ifdef PRINT_MODE + for ( dim_t tid_i = 0; tid_i < nt; ++tid_i ) +#else + for ( dim_t tid_i = 0; tid_i < nt - 1; ++tid_i ) +#endif + { + const dim_t uops_ta = n_uops_per_thr + ( tid_i < n_uops_pt_left ? 1 : 0 ); + dim_t uops_tba = uops_ta; + dim_t j = j_st_cur; + dim_t n_ut_for_me = 0; + bool done_e = FALSE; + bool search_tri = FALSE; + + PGUARD printf( "tid_i: %ld n_uops_ta: %3ld \n", tid_i, uops_tba ); + PGUARD printf( "tid_i: %ld j: %2ld ( n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + + // This code begins allocating uops when the starting point is somewhere + // after the first microtile. Typically this will not be enough to + // allocate all uops, except for situations where the number of threads + // is high relative to the number of utile columns, in which case the + // code signals an early finish (via done_e). + if ( 0 < i_st_cur ) + { + // Compute the number of uops needed to update each utile in the + // current column. + const dim_t k_iter_j = bli_tlb_trmm_rl_k_iter( diagoff_iter, k_iter, j ); + + dim_t i; + + #if 0 + + // Starting from i_st_cur within the current utile column, allocate + // utiles until (a) we run out of utiles in the column (which is tyipcally + // what happens), or (b) we finish allocating all uops for the current + // thread (uops_tba drops to zero or less). + for ( i = i_st_cur; i < m_iter; ++i ) + { + n_ut_for_me += 1; + + const dim_t uops_tba_new = uops_tba - k_iter_j; + + uops_tba = uops_tba_new; + + PGUARD printf( "tid_i: %ld i: %2ld (0 n_ut_cur: %ld) (uops_alloc: %ld) (k_iter_j: %ld)\n", + tid_i, i, n_ut_for_me, uops_ta - uops_tba, k_iter_j ); + + if ( uops_tba_new <= 0 ) { j_en_cur = j; i_en_cur = i; done_e = TRUE; + break; } + } + + // If we traversed the entire column (regardless of whether we finished + // allocating utiles for the current thread), increment j to the next + // column, which is where we'll continue our search for the current tid + // (or start our search for the next tid if we finished allocating utiles). + // Additionally, if we finished traversing all utile columns, mark the + // last utile of the last column as the end point, and set the "done early" + // flag. + if ( i == m_iter ) + { + j += 1; + if ( j == n_iter ) { j_en_cur = j - 1; i_en_cur = m_iter - 1; done_e = TRUE; } + } + + #else + + // Compute the number of utiles left to allocate under the (probably false) + // assumption that all utiles incur the same uop cost (k_iter_j) to update. + // Also compute the number of utiles that remain in the current column. + const dim_t n_ut_tba_j = uops_tba / k_iter_j + ( uops_tba % k_iter_j ? 1 : 0 ); + const dim_t n_ut_rem_j = m_iter - i_st_cur; + + // Compare the aforementioned values. If n_ut_tba_j is less than or equal to + // the number of remaining utiles in the column, we can finish allocating + // without moving to the next column. But if n_ut_tba_j exceeds n_ut_rem_j, + // then we aren't done yet, so allocate what we can and move on. + if ( n_ut_tba_j <= n_ut_rem_j ) + { + n_ut_for_me += n_ut_tba_j; + uops_tba -= n_ut_tba_j * k_iter_j; + i = i_st_cur + n_ut_tba_j; + + j_en_cur = j; i_en_cur = i - 1; done_e = TRUE; + } + else // if ( n_ut_rem_j < n_ut_tba_j ) + { + n_ut_for_me += n_ut_rem_j; + uops_tba -= n_ut_rem_j * k_iter_j; + i = i_st_cur + n_ut_rem_j; + } + + PGUARD printf( "tid_i: %ld i: %2ld (* n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, i-1, n_ut_for_me, uops_ta - uops_tba ); + + // If we allocated all utiles in the column (regardless of whether we finished + // allocating utiles for the current thread), increment j to the next column, + // which is where we'll continue our search for the current tid's end point + // (or start our search through the next tid's range if we finished allocating + // the current tid's utiles). Additionally, if we allocated utiles from the + // last column, mark the tid's end point and set the "done early" flag. + if ( i == m_iter ) + { + j += 1; i = 0; + if ( j == n_iter ) { j_en_cur = j - 1; i_en_cur = m_iter - 1; done_e = TRUE; } + + PGUARD printf( "tid_i: %ld j: %2ld (! n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + } + + #endif + } + + // This code advances over as many columns of utiles as possible, within + // the rectangular region (i.e., pre-diagonal), and then walks down to + // the correct utile within the subsequent column. However, note that + // this code gets skipped entirely if the previous code block was able + // to allocate all of the current tid's uops. + if ( !done_e ) + { + // If j is positioned somewhere within the rectangular region, we can + // skip over as many utile columns as possible with some integer math. + // And depending on how many uops we were able to allocate relative to + // the number of columns that exist, we may need to walk through the + // triangular region as well. But if j is already in the triangular + // region, we set a flag so that we execute the code that will walk + // through those columns. + if ( j < diagoff_iter ) + { + const dim_t j_inc0 = uops_tba / uops_per_col_rect; + const dim_t j_left0 = uops_tba % uops_per_col_rect; + + // We need to set a hard limit on how much j_inc can be. Namely, + // it should not exceed the number of utile columns that are left + // in the rectangular region of the matrix, nor should it exceed + // the total number of utile columns that are left. + const dim_t j_inc1 = bli_min( j_inc0, diagoff_iter - j ); + const dim_t j_inc = bli_min( j_inc1, n_iter - j ); + const dim_t delta = j_inc0 - j_inc; + const dim_t j_left = j_left0 + delta * uops_per_col_rect; + + // Increment j by the number of full utile columns we allocate, and + // set the remaining utile ops to be allocated to the remainder. + j += j_inc; + uops_tba = j_left; + + n_ut_for_me += j_inc * m_iter; + + PGUARD printf( "tid_i: %ld advanced to col: %2ld (uops traversed: %ld)\n", + tid_i, j, uops_per_col_rect * j_inc ); + PGUARD printf( "tid_i: %ld j: %2ld (1 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + PGUARD printf( "tid_i: %ld uops left to alloc: %2ld \n", tid_i, j_left ); + + if ( uops_tba == 0 ) + { + // If advancing j_inc columns allocated all of our uops, then + // designate the last iteration of the previous column as the + // end point. + j_en_cur = j - 1; + i_en_cur = m_iter - 1; + search_tri = FALSE; + + PGUARD printf( "tid_i: %ld j: %2ld (2 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + } + else if ( j > n_iter ) bli_abort(); // Safety check; should never execute. + else if ( j == n_iter ) + { + // If we still have at least some uops to allocate, and advancing + // j_inc columns landed us at the beginning of the first non- + // existent column (column n_iter), then we're done. (The fact + // that we didn't get to allocate all of our uops just means that + // the lower tids slightly overshot their allocations, leaving + // fewer uops for the last thread.) + search_tri = FALSE; + PGUARD printf( "tid_i: %ld j: %2ld (3 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + } + else if ( j < diagoff_iter ) + { + // If we still have at least some uops to allocate, and advancing + // j_inc columns landed us at the beginning of a column that is + // still in the rectangular region, then we don't need to enter + // the triangular region (if it even exists). The code below will + // walk down the current column and find the utile that puts us + // over the top. + search_tri = FALSE; + PGUARD printf( "tid_i: %ld j: %2ld (4 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + } + else // if ( 0 < uops_tba && j == diagoff_iter && j < n_iter ) + { + // If we have at least some uops to allocate, and we still have + // at least some columns to process, then we set a flag to + // indicate that we still need to step through the triangular + // region. + search_tri = TRUE; + PGUARD printf( "tid_i: %ld j: %2ld (5 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + } + } + else /* if ( diagoff_iter <= j ) */ + { + PGUARD printf( "tid_i: %ld j: %2ld >= diagoff_iter: %ld\n", + tid_i, j, diagoff_iter ); + search_tri = TRUE; + } + + PGUARD printf( "tid_i: %ld j: %2ld search_tri: %d\n", tid_i, j, search_tri ); + + if ( search_tri ) + { + // If we still have some uops to allocate in the triangular region, + // we first allocate as many full utile columns as possible without + // exceeding the number of uops left to be allocated. + for ( ; j < n_iter; ++j ) + { + const dim_t k_iter_j = bli_tlb_trmm_rl_k_iter( diagoff_iter, k_iter, j ); + const dim_t n_uops_j = k_iter_j * m_iter; + + PGUARD printf( "tid_i: %ld j: %2ld (6 n_ut_cur: %ld) (uops_alloc: %ld) (n_uops_j: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba, n_uops_j ); + + if ( uops_tba == 0 ) + { + PGUARD printf( "tid_i: %ld j: %2ld (7 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + // If advancing over the previous column allocated all of + // our uops, then designate the last iteration of the + // previous column as the end point. + j_en_cur = j - 1; + i_en_cur = m_iter - 1; + break; + } + if ( n_uops_j <= uops_tba ) + { + // If advancing over the current column doesn't exceed the + // number of uops left to allocate, then allocate them. (If + // n_uops_j == uops_tba, then we'll be done shortly after + // incrementing j.) + n_ut_for_me += m_iter; + uops_tba -= n_uops_j; + + PGUARD printf( "tid_i: %ld j: %2ld (8 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + } + else // if ( uops_tba < n_uops_j ) + { + PGUARD printf( "tid_i: %ld j: %2ld (9 n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba ); + // If we can finish allocating all the remaining uops + // with the utiles in the current column, then we break + // out of the loop without updating j, n_ut_for_me, or + // uops_tba. The remaining uops will be allocated in + // the loop over m_iter below. + break; + } + } + } + + // If there are any uops left to allocate, and we haven't already + // exhausted all allocatable utiles, it means that we have to walk down + // the current column and find the utile that puts us over the top. + if ( 0 < uops_tba && j < n_iter ) + { + const dim_t k_iter_j = bli_tlb_trmm_rl_k_iter( diagoff_iter, k_iter, j ); + + PGUARD printf( "tid_i: %ld j: %2ld (A n_ut_cur: %ld) (uops_alloc: %ld) (k_iter_j: %ld)\n", + tid_i, j, n_ut_for_me, uops_ta - uops_tba, k_iter_j ); + + #if 0 + + dim_t i; + for ( i = 0; i < m_iter; ++i ) + { + n_ut_for_me += 1; + const dim_t uops_tba_new = uops_tba - k_iter_j; + uops_tba = uops_tba_new; + PGUARD printf( "tid_i: %ld i: %2ld (B n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, i, n_ut_for_me, uops_ta - uops_tba ); + if ( uops_tba_new <= 0 ) { j_en_cur = j; i_en_cur = i; break; } + } + + if ( i == m_iter ) + { + j += 1; + if ( j == n_iter ) { j_en_cur = j - 1; i_en_cur = m_iter - 1; } + } + + #else + + const dim_t n_ut_j = uops_tba / k_iter_j + ( uops_tba % k_iter_j ? 1 : 0 ); + const dim_t i = n_ut_j - 1; + + uops_tba -= n_ut_j * k_iter_j; + n_ut_for_me += n_ut_j; + + j_en_cur = j; i_en_cur = i; + + PGUARD printf( "tid_i: %ld i: %2ld (b n_ut_cur: %ld) (uops_alloc: %ld)\n", + tid_i, i, n_ut_for_me, uops_ta - uops_tba ); + + #endif + } + else // if ( uops_tba <= 0 || j == n_iter ) + { + j_en_cur = j - 1; + i_en_cur = m_iter - 1; + } + } + + PGUARD printf( "tid_i: %ld done! (C n_ut_cur: %ld) (overshoot: %ld out of %ld)\n", + tid_i, n_ut_for_me, -uops_tba, uops_ta ); + + if ( tid_i == tid ) + { + *j_st_p = j_st_cur; + *i_st_p = i_st_cur; + *j_en_p = j_en_cur; + *i_en_p = i_en_cur; + return n_ut_for_me; + } + + // Use the current tid's ending i,j values to determine the starting i,j + // values for the next tid. + j_st_cur = j_en_cur; + i_st_cur = i_en_cur + 1; + if ( i_st_cur == m_iter ) { j_st_cur += 1; i_st_cur = 0; } + + PGUARD printf( "tid_i: %ld (D n_ut_cur: %ld)\n", + tid_i, n_ut_for_me ); + PGUARD printf( "tid_i: %ld tid %ld will start at j,i: %ld %ld\n", + tid_i, tid_i + 1, j_st_cur, i_st_cur ); + PGUARD printf( "---------------------------\n" ); + } + +#ifndef PRINT_MODE + + // + // -- Step 4: Handle the last thread's allocation -------------------------- + // + + // An optimization: The above loop runs to nt - 1 rather than nt since it's + // easy to count the number of utiles allocated to the last thread. + const dim_t n_ut_for_me = m_iter - i_st_cur + + (n_iter - j_st_cur - 1) * m_iter; + *j_st_p = j_st_cur; + *i_st_p = i_st_cur; + *j_en_p = n_iter - 1; + *i_en_p = m_iter - 1; + + PGUARD printf( "tid_i: %ld (E n_ut_for_me: %ld) (j,i_st: %ld %ld)\n", + tid, n_ut_for_me, j_st_cur, i_st_cur ); + + return n_ut_for_me; +#else + // This line should never execute, but we need it to satisfy the compiler. + return -1; +#endif +} + diff --git a/frame/thread/bli_thread_range_tlb.h b/frame/thread/bli_thread_range_tlb.h new file mode 100644 index 0000000000..b344f09ef8 --- /dev/null +++ b/frame/thread/bli_thread_range_tlb.h @@ -0,0 +1,192 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_THREAD_RANGE_TLB_H +#define BLIS_THREAD_RANGE_TLB_H + +#if 0 +dim_t bli_thread_range_tlb + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const uplo_t uplo, + const dim_t m_iter, + const dim_t n_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +#endif +dim_t bli_thread_range_tlb_l + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +dim_t bli_thread_range_tlb_u + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +dim_t bli_thread_range_tlb_d + ( + const dim_t nt, + const dim_t tid, + const dim_t m_iter, + const dim_t n_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); + +// --- + +dim_t bli_thread_range_tlb_trmm_ll + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +dim_t bli_thread_range_tlb_trmm_lu + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +dim_t bli_thread_range_tlb_trmm_lx_impl + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const uplo_t uplo, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +#if 0 +dim_t bli_thread_range_tlb_trmm_r + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const uplo_t uplo, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +#endif + +// --- + +dim_t bli_thread_range_tlb_trmm_rl + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +dim_t bli_thread_range_tlb_trmm_ru + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p + ); +dim_t bli_thread_range_tlb_trmm_rl_impl + ( + const dim_t nt, + const dim_t tid, + const doff_t diagoff, + const dim_t m_iter, + const dim_t n_iter, + const dim_t k_iter, + const dim_t mr, + const dim_t nr, + inc_t* j_st_p, + inc_t* i_st_p, + inc_t* j_en_p, + inc_t* i_en_p + ); + +#endif diff --git a/frame/thread/old/bli_thread_range_snake.c b/frame/thread/old/bli_thread_range_snake.c new file mode 100644 index 0000000000..11a287659c --- /dev/null +++ b/frame/thread/old/bli_thread_range_snake.c @@ -0,0 +1,120 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#if 0 +void bli_thread_range_snake_jr + ( + const thrinfo_t* thread, + doff_t diagoff, + uplo_t uplo, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end, + dim_t* inc + ) +{ + // Use snake partitioning of jr loop. + + // NOTE: This function currently assumes that edge cases are handled + // "high" and therefore ignores handle_edge_low. This is because the + // function is only used by gemmt and friends (herk/her2k/syrk/syr2k). + // These operations, unlike trmm/trmm3 and trsm, never require + // low-range edge cases. + + const dim_t tid = bli_thrinfo_work_id( thread ); + const dim_t nt = bli_thrinfo_n_way( thread ); + + const dim_t n_left = n % bf; + const dim_t n_iter = n / bf + ( n_left ? 1 : 0 ); + + if ( bli_is_lower( uplo ) ) + { + // Use the thrinfo_t work id as the thread's starting index. + const dim_t st = tid; + + // This increment will be too big for some threads with only one unit + // (NR columns, or an edge case) of work, but that's okay since all that + // matters is that st + in >= en, which will cause that thread's jr loop + // to not execute beyond the first iteration. + const dim_t in = 2 * ( nt - tid ) - 1; + + dim_t en = st + in + 1; + + // Don't let the thread's end index exceed n_iter. + if ( n_iter < en ) en = n_iter; + + *start = st * bf; + *end = en * bf; // - ( bf - n_left ); + *inc = in * bf; + } + else // if ( bli_is_upper( uplo ) ) + { + dim_t st = n_iter - 2 * nt + tid; + + const dim_t in = 2 * ( nt - tid ) - 1; + + dim_t en = st + in + 1; + + #if 1 + // When nt exceeds half n_iter, some threads will only get one unit + // (NR columns, or an edge case) of work. This manifests as st being + // negative, and thus we need to move their start index to their other + // assigned unit in the positive index range. + if ( st < 0 ) st += in; + + // If the start index is *still* negative, which happens for some + // threads when nt exceeds n_iter, then manually assign this thread + // an empty index range. + if ( st < 0 ) { st = 0; en = 0; } + #else + if ( 0 <= st + in ) { st += in; } + else { st = 0; en = 0; } + #endif + + #if 0 + printf( "thread_range_snake_jr(): tid %d: sta end = %3d %3d %3d\n", + (int)tid, (int)(st), (int)(en), (int)(in) ); + #endif + + *start = st * bf; + *end = en * bf; + *inc = in * bf; + } +} +#endif diff --git a/frame/1m/packm/bli_packm_thrinfo.h b/frame/thread/old/bli_thread_range_snake.h similarity index 70% rename from frame/1m/packm/bli_packm_thrinfo.h rename to frame/thread/old/bli_thread_range_snake.h index 1ac7f88dfb..73fd4ae733 100644 --- a/frame/1m/packm/bli_packm_thrinfo.h +++ b/frame/thread/old/bli_thread_range_snake.h @@ -32,34 +32,22 @@ */ -// -// thrinfo_t macros specific to packm. -// - -/* -#define bli_packm_thread_my_iter( index, thread ) \ -\ - ( index % thread->n_way == thread->work_id % thread->n_way ) -*/ - -#define bli_packm_my_iter_rr( i, start, end, work_id, n_way ) \ -\ - ( i % n_way == work_id % n_way ) - -#define bli_packm_my_iter_sl( i, start, end, work_id, n_way ) \ -\ - ( start <= i && i < end ) - -// Define a general-purpose version of bli_packm_my_iter() whose definition -// depends on whether slab or round-robin partitioning was requested at -// configure-time. -#ifdef BLIS_ENABLE_JRIR_SLAB - - #define bli_packm_my_iter bli_packm_my_iter_sl - -#else // BLIS_ENABLE_JRIR_RR - - #define bli_packm_my_iter bli_packm_my_iter_rr - +#ifndef BLIS_THREAD_RANGE_SNAKE_H +#define BLIS_THREAD_RANGE_SNAKE_H + +#if 0 +void bli_thread_range_snake_jr + ( + const thrinfo_t* thread, + doff_t diagoff, + uplo_t uplo, + dim_t n, + dim_t bf, + bool handle_edge_low, + dim_t* start, + dim_t* end, + dim_t* inc + ); #endif +#endif diff --git a/sandbox/gemmlike/bls_gemm_bp_var1.c b/sandbox/gemmlike/bls_gemm_bp_var1.c index 02f7458adf..b61140743f 100644 --- a/sandbox/gemmlike/bls_gemm_bp_var1.c +++ b/sandbox/gemmlike/bls_gemm_bp_var1.c @@ -344,11 +344,11 @@ void PASTECH2(bls_,ch,varname) \ \ /* Compute the addresses of the next micropanels of A and B. */ \ a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ - if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + if ( bli_is_last_iter_slrr( i, ir_end, ir_tid, ir_nt ) ) \ { \ a2 = a_ic_use; \ b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ - if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + if ( bli_is_last_iter_slrr( j, jr_end, jr_tid, jr_nt ) ) \ b2 = b_pc_use; \ } \ \ diff --git a/sandbox/gemmlike/bls_l3_packm_var1.c b/sandbox/gemmlike/bls_l3_packm_var1.c index 7c2c4e9a90..b37d34cce3 100644 --- a/sandbox/gemmlike/bls_l3_packm_var1.c +++ b/sandbox/gemmlike/bls_l3_packm_var1.c @@ -131,10 +131,10 @@ void PASTECH2(bls_,ch,varname) \ dim_t it_start, it_end, it_inc; \ \ /* Determine the thread range and increment using the current thread's - packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + packm thrinfo_t node. NOTE: The definition of bli_thread_range_slrr() will depend on whether slab or round-robin partitioning was requested at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ \ /* Iterate over every logical micropanel in the source matrix. */ \ for ( ic = ic0, it = 0; it < n_iter; \ @@ -147,10 +147,10 @@ void PASTECH2(bls_,ch,varname) \ ctype* restrict c_use = c_begin; \ ctype* restrict p_use = p_begin; \ \ - /* The definition of bli_packm_my_iter() will depend on whether slab + /* The definition of bli_is_my_iter() will depend on whether slab or round-robin partitioning was requested at configure-time. (The default is slab.) */ \ - if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + if ( bli_is_my_iter( it, it_start, it_end, tid, nt ) ) \ { \ PASTECH2(bls_,ch,packm_cxk) \ ( \ diff --git a/sandbox/gemmlike/bls_l3_packm_var2.c b/sandbox/gemmlike/bls_l3_packm_var2.c index 94ee0efcd8..b3efbbc28f 100644 --- a/sandbox/gemmlike/bls_l3_packm_var2.c +++ b/sandbox/gemmlike/bls_l3_packm_var2.c @@ -131,10 +131,10 @@ void PASTECH2(bls_,ch,varname) \ dim_t it_start, it_end, it_inc; \ \ /* Determine the thread range and increment using the current thread's - packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + packm thrinfo_t node. NOTE: The definition of bli_thread_range_slrr() will depend on whether slab or round-robin partitioning was requested at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ + bli_thread_range_slrr( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ \ /* Iterate over every logical micropanel in the source matrix. */ \ for ( ic = ic0, it = 0; it < n_iter; \ @@ -147,10 +147,10 @@ void PASTECH2(bls_,ch,varname) \ ctype* restrict c_use = c_begin; \ ctype* restrict p_use = p_begin; \ \ - /* The definition of bli_packm_my_iter() will depend on whether slab + /* The definition of bli_is_my_iter() will depend on whether slab or round-robin partitioning was requested at configure-time. (The default is slab.) */ \ - if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + if ( bli_is_my_iter( it, it_start, it_end, tid, nt ) ) \ { \ /* NOTE: We assume here that kappa = 1 and therefore ignore it. If we're wrong, this will get someone's attention. */ \ diff --git a/testsuite/src/test_libblis.c b/testsuite/src/test_libblis.c index 67fd384f46..7773d6cf0a 100644 --- a/testsuite/src/test_libblis.c +++ b/testsuite/src/test_libblis.c @@ -809,7 +809,7 @@ void libblis_test_output_params_struct( FILE* os, test_params_t* params ) char impl_str[32]; char def_impl_set_str[32]; char def_impl_unset_str[32]; - char jrir_str[16]; + char jrir_str[32]; const bool has_openmp = bli_info_get_enable_openmp(); const bool has_pthreads = bli_info_get_enable_pthreads(); @@ -844,8 +844,9 @@ void libblis_test_output_params_struct( FILE* os, test_params_t* params ) else sprintf( def_impl_set_str, "single" ); // Describe the status of jrir thread partitioning. - if ( bli_info_get_thread_part_jrir_slab() ) sprintf( jrir_str, "slab" ); - else /*bli_info_get_thread_part_jrir_rr()*/ sprintf( jrir_str, "round-robin" ); + if ( bli_info_get_thread_jrir_slab() ) sprintf( jrir_str, "slab" ); + else if ( bli_info_get_thread_jrir_rr() ) sprintf( jrir_str, "round-robin" ); + else /*bli_info_get_thread_jrir_tlb()*/ sprintf( jrir_str, "tile-level (slab)" ); char nt_str[16]; char jc_nt_str[16]; diff --git a/testsuite/src/test_trmm.c b/testsuite/src/test_trmm.c index 0504b33158..497ecf97ea 100644 --- a/testsuite/src/test_trmm.c +++ b/testsuite/src/test_trmm.c @@ -271,7 +271,10 @@ void libblis_test_trmm_impl switch ( iface ) { case BLIS_TEST_SEQ_FRONT_END: +//bli_printm( "a", a, "%5.2f", "" ); +//bli_printm( "b", b, "%5.2f", "" ); bli_trmm( side, alpha, a, b ); +//bli_printm( "b after", b, "%5.2f", "" ); break; default: