From 6098c12fcbb2eacb98d0dfd4947c9aab773dd31d Mon Sep 17 00:00:00 2001 From: ago109 Date: Thu, 27 Jun 2024 12:08:12 -0400 Subject: [PATCH] fixed bug in matrix normalizer for L2 --- ngclearn/utils/model_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index d3a3a3d71..59bc8a32a 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -134,14 +134,14 @@ def normalize_matrix(M, wnorm, order=1, axis=0, scale=1.): a normalized value matrix """ if order == 2: ## denominator is L2 norm - wOrdSum = jnp.maximum(jnp.sum(jnp.square(M), axis=axis, keepdims=True), 1e-8) + wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(M), axis=axis, keepdims=True)), 1e-8) else: ## denominator is L1 norm wOrdSum = jnp.maximum(jnp.sum(jnp.abs(M), axis=axis, keepdims=True), 1e-8) m = (wOrdSum == 0.).astype(dtype=jnp.float32) wOrdSum = wOrdSum * (1. - m) + m #wAbsSum[wAbsSum == 0.] = 1. - # _M = M * (wnorm/wOrdSum) - dM = ((wnorm/wOrdSum) - 1.) * M - _M = M + dM * scale + _M = M * (wnorm/wOrdSum) + #dM = ((wnorm/wOrdSum) - 1.) * M + #_M = M + dM * scale return _M @jit