Skip to content

Commit

Permalink
added lower-triangle/upper-triangle masking for weight dist
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jun 28, 2024
1 parent 16588fa commit 1acaaae
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion ngclearn/utils/weight_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,14 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False):
dkey: PRNG key to control determinism of this routine
init_kernel: dictionary specifying the distribution type and its
parameters (default: `uniform` dist w/ `amin=0.02`, `amax=0.8`)
parameters (default: `uniform` dist w/ `amin=0.02`, `amax=0.8`) --
note that kernel dictionary may contain "post-processing" arguments
that can be "stacked" on top of the base matrix, for example, you
can pass in a dictionary:
{"dist": "uniform", "hollow": True, "lower_triangle": True} which
will create a unit-uniform value matrix with upper triangle and main
diagonal values masked to zero (hollow matrix masking is applied after
lower-triangle masking)
:Note: Currently supported distribution (dist) kernel schemes include:
"constant" (value);
Expand All @@ -115,6 +122,8 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False):
while currently supported post-processing keyword arguments include:
"amin" (clip weights values to be >= amin);
"amax" (clip weights values to be <= amax);
"lower_triangle" (extract lower triangle of params, set rest to 0);
"upper_triangle" (extract upper triangle of params, set rest to 0);
"hollow" (zero out values along main diagonal);
"eye" (zero out off-diagonal values);
"n_row_active" (keep only n random rows non-masked/zero);
Expand Down Expand Up @@ -180,6 +189,8 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False):
## check for any additional distribution post-processing kwargs (e.g., clipping)
clip_min = _init_kernel.get("amin")
clip_max = _init_kernel.get("amax")
lower_triangle = init_kernel("lower_triangle")
upper_triangle = init_kernel("upper_triangle")
is_hollow = _init_kernel.get("hollow", False)
is_eye = _init_kernel.get("eye", False)
n_row_active = _init_kernel.get("n_row_active", None)
Expand All @@ -195,6 +206,12 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False):
params = np.minimum(params, clip_max)
else:
params = jnp.minimum(params, clip_max)
if lower_triangle: ## extract lower triangle of params matrix
ltri_params = jax.numpy.tril(params.shape[0])
params = ltri_params
if upper_triangle: ## extract upper triangle of params matrix
ltri_params = jax.numpy.triu(params.shape[0])
params = ltri_params
if is_hollow: ## apply a hollow mask
if use_numpy:
params = (1. - np.eye(N=shape[0], M=shape[1])) * params
Expand Down

0 comments on commit 1acaaae

Please sign in to comment.