Skip to content

Commit d0b35e3

Browse files
author
Jasper van den Bosch
committed
rank_transform() can deal with unknown dissim measure
1 parent 12da07e commit d0b35e3

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

src/rsatoolbox/rdm/transform.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
# -*- coding: utf-8 -*-
33
""" transforms, which can be applied to RDMs
44
"""
5-
5+
from __future__ import annotations
66
from copy import deepcopy
77
import numpy as np
88
from scipy.stats import rankdata
99
from .rdms import RDMs
1010

1111

12-
def rank_transform(rdms, method='average'):
12+
def rank_transform(rdms: RDMs, method='average'):
1313
""" applies a rank_transform and generates a new RDMs object
1414
This assigns a rank to each dissimilarity estimate in the RDM,
1515
deals with rank ties and saves ranks as new dissimilarity estimates.
@@ -30,9 +30,9 @@ def rank_transform(rdms, method='average'):
3030
dissimilarities = rdms.get_vectors()
3131
dissimilarities = np.array([rankdata(dissimilarities[i], method=method)
3232
for i in range(rdms.n_rdm)])
33-
measure = rdms.dissimilarity_measure
34-
if not measure[-7:] == '(ranks)':
35-
measure = measure + ' (ranks)'
33+
measure = rdms.dissimilarity_measure or ''
34+
if '(ranks)' not in measure:
35+
measure = (measure + ' (ranks)').strip()
3636
rdms_new = RDMs(dissimilarities,
3737
dissimilarity_measure=measure,
3838
descriptors=deepcopy(rdms.descriptors),
@@ -103,8 +103,9 @@ def transform(rdms, fun):
103103
"""
104104
dissimilarities = rdms.get_vectors()
105105
dissimilarities = fun(dissimilarities)
106+
meas = 'transformed ' + rdms.dissimilarity_measure
106107
rdms_new = RDMs(dissimilarities,
107-
dissimilarity_measure='transformed ' + rdms.dissimilarity_measure,
108+
dissimilarity_measure=meas,
108109
descriptors=deepcopy(rdms.descriptors),
109110
rdm_descriptors=deepcopy(rdms.rdm_descriptors),
110111
pattern_descriptors=deepcopy(rdms.pattern_descriptors))

tests/test_rdm.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,12 @@ def test_rank_transform(self):
246246
self.assertEqual(rank_rdm.n_cond, rdms.n_cond)
247247
self.assertEqual(rank_rdm.dissimilarity_measure, 'Euclidean (ranks)')
248248

249+
def test_rank_transform_unknown_measure(self):
250+
from rsatoolbox.rdm import rank_transform
251+
rdms = rsr.RDMs(dissimilarities=np.zeros((8, 10)))
252+
rank_rdm = rank_transform(rdms)
253+
self.assertEqual(rank_rdm.dissimilarity_measure, '(ranks)')
254+
249255
def test_sqrt_transform(self):
250256
from rsatoolbox.rdm import sqrt_transform
251257
dis = np.zeros((8, 10))
@@ -463,17 +469,19 @@ def test_copy(self):
463469
)
464470
)
465471
copy = orig.copy()
466-
## We don't want a reference:
472+
# We don't want a reference:
467473
self.assertIsNot(copy, orig)
468474
self.assertIsNot(copy.dissimilarities, orig.dissimilarities)
469475
self.assertIsNot(
470476
copy.pattern_descriptors.get('order'),
471477
orig.pattern_descriptors.get('order')
472478
)
473-
## But check that attributes are equal
479+
# But check that attributes are equal
474480
assert_array_equal(copy.dissimilarities, orig.dissimilarities)
475-
self.assertEqual(copy.dissimilarity_measure,
476-
orig.dissimilarity_measure)
481+
self.assertEqual(
482+
copy.dissimilarity_measure,
483+
orig.dissimilarity_measure
484+
)
477485
self.assertEqual(copy.descriptors, orig.descriptors)
478486
self.assertEqual(copy.rdm_descriptors, orig.rdm_descriptors)
479487
assert_array_equal(

0 commit comments

Comments
 (0)