Skip to content

Commit e43546c

Browse files
Merge pull request #323 from rsagroup/combine-rescale-threshold
threshold for rdm.combine.rescale
2 parents b802f17 + d9f26da commit e43546c

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

src/rsatoolbox/rdm/combine.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def pdescs(rdms, descriptor):
9696
)
9797

9898

99-
def rescale(rdms, method: str = 'evidence'):
99+
def rescale(rdms, method: str = 'evidence', threshold=1e-8):
100100
"""Bring RDMs closer together
101101
102102
Iteratively scales RDMs based on pairs in-common.
@@ -105,11 +105,15 @@ def rescale(rdms, method: str = 'evidence'):
105105
Args:
106106
method (str, optional): One of 'evidence', 'setsize' or
107107
'simple'. Defaults to 'evidence'.
108+
threshold (float): Stop iterating when the sum of squares
109+
difference between iterations is smaller than this value.
110+
A smaller value means more iterations, but the algorithm
111+
may not always converge.
108112
109113
Returns:
110114
RDMs: RDMs object with the aligned RDMs
111115
"""
112-
aligned, weights = _rescale(rdms.dissimilarities, method)
116+
aligned, weights = _rescale(rdms.dissimilarities, method, threshold)
113117
rdm_descriptors = deepcopy(rdms.rdm_descriptors)
114118
if weights is not None:
115119
rdm_descriptors['rescalingWeights'] = weights
@@ -166,7 +170,7 @@ def _scale(vectors: ndarray) -> ndarray:
166170
return vectors / sqrt(_ss(vectors))
167171

168172

169-
def _rescale(dissim: ndarray, method: str) -> Tuple[ndarray, ndarray]:
173+
def _rescale(dissim: ndarray, method: str, threshold=1e-8) -> Tuple[ndarray, ndarray]:
170174
"""Rescale RDM vectors
171175
172176
See :meth:`rsatoolbox.rdm.combine.rescale`
@@ -191,7 +195,7 @@ def _rescale(dissim: ndarray, method: str) -> Tuple[ndarray, ndarray]:
191195

192196
current_estimate = _scale(_mean(dissim))
193197
prev_estimate = np.full([n_conds, ], -inf)
194-
while _ss(current_estimate - prev_estimate) > 1e-8:
198+
while _ss(current_estimate - prev_estimate) > threshold:
195199
prev_estimate = current_estimate.copy()
196200
tiled_estimate = np.tile(current_estimate, [n_rdms, 1])
197201
tiled_estimate[np.isnan(dissim)] = nan

tests/test_rdms_combine.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
# pylint: disable=import-outside-toplevel, no-self-use
44
from unittest import TestCase
5-
from numpy import array, nan, isnan
5+
from numpy import array, nan, isnan, mean, abs as _abs, diff
66
from numpy.testing import assert_almost_equal, assert_array_equal
77
from scipy.stats import pearsonr
88

@@ -70,6 +70,28 @@ def test_rescale_setsize(self):
7070
decimal=4
7171
)
7272

73+
def test_rescale_threshold(self):
74+
"""The rescale function bring the RDMs as close together as possible
75+
"""
76+
from rsatoolbox.rdm.rdms import RDMs
77+
from rsatoolbox.rdm.combine import rescale
78+
partial_rdms = RDMs(
79+
dissimilarities=array([
80+
[ 1, 2, nan, 3, nan, nan],
81+
[nan, nan, nan, 4, 5, nan],
82+
])
83+
)
84+
## high threshold, fewer iterations, substantial difference remaining
85+
rescaled_rdms = rescale(partial_rdms, method='simple', threshold=10)
86+
common_pair = rescaled_rdms.dissimilarities[:, 3]
87+
rel_diff = _abs(diff(common_pair)/mean(common_pair))
88+
self.assertGreater(rel_diff[0], 0.1)
89+
## low threshold, more iterations, difference small
90+
rescaled_rdms = rescale(partial_rdms, method='simple', threshold=0.00001)
91+
common_pair = rescaled_rdms.dissimilarities[:, 3]
92+
rel_diff = _abs(diff(common_pair)/mean(common_pair))
93+
self.assertLess(rel_diff[0], 0.01)
94+
7395
def test_mean_no_weights(self):
7496
"""RDMs.mean() returns an RDMs with the nan omitted mean of the rdms
7597
"""

0 commit comments

Comments
 (0)