From 1acaaae2cf7d86d4f754d848c69f347ccc7920a9 Mon Sep 17 00:00:00 2001 From: ago109 Date: Fri, 28 Jun 2024 14:38:49 -0400 Subject: [PATCH] added lower-triangle/upper-triangle masking for weight dist --- ngclearn/utils/weight_distribution.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/ngclearn/utils/weight_distribution.py b/ngclearn/utils/weight_distribution.py index 9c3b3c89..562a8b61 100755 --- a/ngclearn/utils/weight_distribution.py +++ b/ngclearn/utils/weight_distribution.py @@ -103,7 +103,14 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False): dkey: PRNG key to control determinism of this routine init_kernel: dictionary specifying the distribution type and its - parameters (default: `uniform` dist w/ `amin=0.02`, `amax=0.8`) + parameters (default: `uniform` dist w/ `amin=0.02`, `amax=0.8`) -- + note that the kernel dictionary may contain "post-processing" arguments + that can be "stacked" on top of the base matrix, for example, you + can pass in a dictionary: + {"dist": "uniform", "hollow": True, "lower_triangle": True} which + will create a unit-uniform value matrix with upper triangle and main + diagonal values masked to zero (lower-triangle masking applied before + hollow matrix masking) :Note: Currently supported distribution (dist) kernel schemes include: "constant" (value); @@ -115,6 +122,8 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False): while currently supported post-processing keyword arguments include: "amin" (clip weights values to be >= amin); "amax" (clip weights values to be <= amin); + "lower_triangle" (extract lower triangle of params, set rest to 0); + "upper_triangle" (extract upper triangle of params, set rest to 0); "hollow" (zero out values along main diagonal); "eye" (zero out off-diagonal values); "n_row_active" (keep only n random rows non-masked/zero); @@ -180,6 +189,8 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False): ## check for any additional 
distribution post-processing kwargs (e.g., clipping) clip_min = _init_kernel.get("amin") clip_max = _init_kernel.get("amax") + lower_triangle = _init_kernel.get("lower_triangle", False) + upper_triangle = _init_kernel.get("upper_triangle", False) is_hollow = _init_kernel.get("hollow", False) is_eye = _init_kernel.get("eye", False) n_row_active = _init_kernel.get("n_row_active", None) @@ -195,6 +206,12 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False): params = np.minimum(params, clip_max) else: params = jnp.minimum(params, clip_max) + if lower_triangle: ## keep lower triangle (incl. main diagonal) of params matrix, zero the rest + tril_fn = np.tril if use_numpy else jnp.tril ## match array backend used by the other masks + params = tril_fn(params) + if upper_triangle: ## keep upper triangle (incl. main diagonal) of params matrix, zero the rest + triu_fn = np.triu if use_numpy else jnp.triu ## match array backend used by the other masks + params = triu_fn(params) if is_hollow: ## apply a hollow mask if use_numpy: params = (1. - np.eye(N=shape[0], M=shape[1])) * params