Skip to content

Commit

Permalink
Merge pull request #54 from NACLab/dev
Browse files Browse the repository at this point in the history
fixed bug in matrix normalizer for L2
  • Loading branch information
ago109 authored Jun 27, 2024
2 parents 8c9775d + 6098c12 commit 8e2e66d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ngclearn/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ def normalize_matrix(M, wnorm, order=1, axis=0, scale=1.):
a normalized value matrix
"""
if order == 2: ## denominator is L2 norm
wOrdSum = jnp.maximum(jnp.sum(jnp.square(M), axis=axis, keepdims=True), 1e-8)
wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(M), axis=axis, keepdims=True)), 1e-8)
else: ## denominator is L1 norm
wOrdSum = jnp.maximum(jnp.sum(jnp.abs(M), axis=axis, keepdims=True), 1e-8)
m = (wOrdSum == 0.).astype(dtype=jnp.float32)
wOrdSum = wOrdSum * (1. - m) + m #wAbsSum[wAbsSum == 0.] = 1.
# _M = M * (wnorm/wOrdSum)
dM = ((wnorm/wOrdSum) - 1.) * M
_M = M + dM * scale
_M = M * (wnorm/wOrdSum)
#dM = ((wnorm/wOrdSum) - 1.) * M
#_M = M + dM * scale
return _M

@jit
Expand Down

0 comments on commit 8e2e66d

Please sign in to comment.