Skip to content

Commit 5badbf3

Browse files
committed
update
1 parent ff77971 commit 5badbf3

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

nlp_class2/pmi.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,13 @@ def remove_punctuation_3(s):
152152

153153

154154
# PMI(w, c) = #(w, c) / #(w) / p(c)
155-
pmi = wc_counts / wc_counts.sum(axis=1) / c_probs
155+
# pmi = wc_counts / wc_counts.sum(axis=1) / c_probs # works only if these are numpy arrays
156+
pmi = wc_counts.multiply(1.0 / wc_counts.sum(axis=1) / c_probs).tocsr()
157+
# this operation changes it to a coo_matrix
158+
# which doesn't have functions we need, e.g. log1p()
159+
# so convert it back to a csr_matrix
156160
print("type(pmi):", type(pmi))
157-
logX = np.log(pmi.A + 1)
161+
logX = pmi.log1p() # would be logX = np.log(pmi.A + 1) in numpy
158162
print("type(logX):", type(logX))
159163
logX[logX < 0] = 0
160164

@@ -180,7 +184,9 @@ def remove_punctuation_3(s):
180184
for epoch in range(10):
181185
print("epoch:", epoch)
182186
delta = W.dot(U.T) + b.reshape(V, 1) + c.reshape(1, V) + mu - logX
183-
cost = ( delta * delta ).sum()
187+
# cost = ( delta * delta ).sum()
188+
cost = np.multiply(delta, delta).sum()
189+
# * is a matrix product if delta is a "matrix" object but element-wise if it is an "array" object
184190
costs.append(cost)
185191

186192
### partially vectorized updates ###

0 commit comments

Comments
 (0)