diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index d3a3a3d71..59bc8a32a 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -134,14 +134,14 @@ def normalize_matrix(M, wnorm, order=1, axis=0, scale=1.): a normalized value matrix """ if order == 2: ## denominator is L2 norm - wOrdSum = jnp.maximum(jnp.sum(jnp.square(M), axis=axis, keepdims=True), 1e-8) + wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(M), axis=axis, keepdims=True)), 1e-8) else: ## denominator is L1 norm wOrdSum = jnp.maximum(jnp.sum(jnp.abs(M), axis=axis, keepdims=True), 1e-8) m = (wOrdSum == 0.).astype(dtype=jnp.float32) wOrdSum = wOrdSum * (1. - m) + m #wAbsSum[wAbsSum == 0.] = 1. - # _M = M * (wnorm/wOrdSum) - dM = ((wnorm/wOrdSum) - 1.) * M - _M = M + dM * scale + _M = M * (wnorm/wOrdSum) + #dM = ((wnorm/wOrdSum) - 1.) * M + #_M = M + dM * scale return _M @jit