File tree 1 file changed +9
-3
lines changed
1 file changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -152,9 +152,13 @@ def remove_punctuation_3(s):
152
152
153
153
154
154
# PMI(w, c) = #(w, c) / #(w) / p(c)
155
- pmi = wc_counts / wc_counts .sum (axis = 1 ) / c_probs
155
+ # pmi = wc_counts / wc_counts.sum(axis=1) / c_probs # works only if numpy arrays
156
+ pmi = wc_counts .multiply (1.0 / wc_counts .sum (axis = 1 ) / c_probs ).tocsr ()
157
+ # this operation changes it to a coo_matrix
158
+ # which doesn't have functions we need, e.g log1p()
159
+ # so convert it back to a csr
156
160
print ("type(pmi):" , type (pmi ))
157
- logX = np .log (pmi .A + 1 )
161
+ logX = pmi . log1p () # would be logX = np.log(pmi.A + 1) in numpy
158
162
print ("type(logX):" , type (logX ))
159
163
logX [logX < 0 ] = 0
160
164
@@ -180,7 +184,9 @@ def remove_punctuation_3(s):
180
184
for epoch in range (10 ):
181
185
print ("epoch:" , epoch )
182
186
delta = W .dot (U .T ) + b .reshape (V , 1 ) + c .reshape (1 , V ) + mu - logX
183
- cost = ( delta * delta ).sum ()
187
+ # cost = ( delta * delta ).sum()
188
+ cost = np .multiply (delta , delta ).sum ()
189
+ # * behaves differently if delta is a "matrix" object vs "array" object
184
190
costs .append (cost )
185
191
186
192
### partially vectorized updates ###
You can’t perform that action at this time.
0 commit comments