diff --git a/breze/learn/base.py b/breze/learn/base.py
index 75017e9..eb483e6 100644
--- a/breze/learn/base.py
+++ b/breze/learn/base.py
@@ -246,27 +246,28 @@ def _make_loss_functions(self, mode=None, givens=None,
         return f_loss, f_d_loss
 
-    def _make_args(self, X, Z, imp_weight=None):
+    def _make_args(self, X, Z, imp_weight=None, n_cycles=False):
         batch_size = getattr(self, 'batch_size', None)
         if batch_size is None:
             X, Z = cast_array_to_local_type(X), cast_array_to_local_type(Z)
+            repeat_kwargs = {'times': n_cycles} if n_cycles else {}
             if imp_weight is not None:
                 imp_weight = cast_array_to_local_type(imp_weight)
-                data = itertools.repeat([X, Z, imp_weight])
+                data = itertools.repeat([X, Z, imp_weight], **repeat_kwargs)
             else:
-                data = itertools.repeat([X, Z])
+                data = itertools.repeat([X, Z], **repeat_kwargs)
         elif batch_size < 1:
             raise ValueError('need strictly positive batch size')
         else:
             if imp_weight is not None:
                 data = iter_minibatches([X, Z, imp_weight], self.batch_size,
-                                        list(self.sample_dim) + [self.sample_dim[0]])
+                                        list(self.sample_dim) + [self.sample_dim[0]], n_cycles=n_cycles)
                 data = ((cast_array_to_local_type(x),
                          cast_array_to_local_type(z),
                          cast_array_to_local_type(w))
                         for x, z, w in data)
             else:
                 data = iter_minibatches([X, Z], self.batch_size,
-                                        self.sample_dim)
+                                        self.sample_dim, n_cycles=n_cycles)
                 data = ((cast_array_to_local_type(x),
                          cast_array_to_local_type(z))
                         for x, z in data)
@@ -373,16 +374,17 @@ def score(self, X, Z, imp_weight=None):
         l : scalar
             Score of the model.
         """
-        X = cast_array_to_local_type(X)
-        Z = cast_array_to_local_type(Z)
-        if imp_weight is not None:
-            imp_weight = cast_array_to_local_type(imp_weight)
         if self.f_score is None:
             self.f_score = self._make_score_function(
                 imp_weight=(imp_weight is not None))
-        if imp_weight is None:
-            return self.f_score(X, Z)
-        return self.f_score(X, Z, imp_weight)
+
+        score = 0
+        sample_count = 0
+        for arg in self._make_args(X, Z, imp_weight, n_cycles=1):
+            samples_in_batch = int(arg[0][0].shape[self.sample_dim[0]])
+            score += self.f_score(*arg[0]) * samples_in_batch
+            sample_count += samples_in_batch
+        return score / sample_count
 
 
 class UnsupervisedModel(Model, BrezeWrapperBase):
@@ -492,7 +494,7 @@ def fit(self, X, W=None):
             if i + 1 >= self.max_iter:
                 break
 
-    def _make_args(self, X, W=None):
+    def _make_args(self, X, W=None, n_cycles=False):
         batch_size = getattr(self, 'batch_size', None)
         use_imp_weight = W is not None
         if self.use_imp_weight != use_imp_weight:
@@ -507,11 +509,12 @@
             sample_dim.append(sample_dim[0])
         if batch_size is None:
-            data = itertools.repeat(item)
+            repeat_kwargs = {'times': n_cycles} if n_cycles else {}
+            data = itertools.repeat(item, **repeat_kwargs)
         elif batch_size < 1:
             raise ValueError('need strictly positive batch size')
         else:
-            data = iter_minibatches(item, self.batch_size, sample_dim)
+            data = iter_minibatches(item, self.batch_size, sample_dim, n_cycles=n_cycles)
         if use_imp_weight:
             data = ((cast_array_to_local_type(x),
                      cast_array_to_local_type(w))
                     for x, w in data)
@@ -542,13 +545,16 @@ def score(self, X, W=None):
         l : scalar
             Score of the model.
 
         """
-        X = cast_array_to_local_type(X)
         if self.f_score is None:
             self.f_score = self._make_score_function()
-        args = [X] if W is None else [X, W]
-        l = self.f_score(*args)
-        return l
+        score = 0
+        sample_count = 0
+        for arg in self._make_args(X, W, n_cycles=1):
+            samples_in_batch = int(arg[0][0].shape[self.sample_dim[0]])
+            score += self.f_score(*arg[0]) * samples_in_batch
+            sample_count += samples_in_batch
+        return score / sample_count
 
 
 class TransformBrezeWrapperMixin(object):