Skip to content

Commit 8e2e66d

Browse files
authored
Merge pull request #54 from NACLab/dev
fixed bug in matrix normalizer for L2
2 parents 8c9775d + 6098c12 commit 8e2e66d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ngclearn/utils/model_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,14 @@ def normalize_matrix(M, wnorm, order=1, axis=0, scale=1.):
134134
a normalized value matrix
135135
"""
136136
if order == 2: ## denominator is L2 norm
137-
wOrdSum = jnp.maximum(jnp.sum(jnp.square(M), axis=axis, keepdims=True), 1e-8)
137+
wOrdSum = jnp.maximum(jnp.sqrt(jnp.sum(jnp.square(M), axis=axis, keepdims=True)), 1e-8)
138138
else: ## denominator is L1 norm
139139
wOrdSum = jnp.maximum(jnp.sum(jnp.abs(M), axis=axis, keepdims=True), 1e-8)
140140
m = (wOrdSum == 0.).astype(dtype=jnp.float32)
141141
wOrdSum = wOrdSum * (1. - m) + m #wAbsSum[wAbsSum == 0.] = 1.
142-
# _M = M * (wnorm/wOrdSum)
143-
dM = ((wnorm/wOrdSum) - 1.) * M
144-
_M = M + dM * scale
142+
_M = M * (wnorm/wOrdSum)
143+
#dM = ((wnorm/wOrdSum) - 1.) * M
144+
#_M = M + dM * scale
145145
return _M
146146

147147
@jit

0 commit comments

Comments
 (0)