Skip to content

Commit 2361bc5

Browse files
Merge pull request #298 from rsagroup/rdms-to-pandas
RDMs to Pandas DataFrame
2 parents 8182ac0 + d0b35e3 commit 2361bc5

File tree

5 files changed

+112
-10
lines changed

5 files changed

+112
-10
lines changed

Diff for: src/rsatoolbox/io/pandas.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Conversions from rsatoolbox classes to pandas table objects
2+
"""
3+
from __future__ import annotations
4+
from typing import TYPE_CHECKING
5+
from pandas import DataFrame
6+
import numpy
7+
from numpy import asarray
8+
if TYPE_CHECKING:
9+
from rsatoolbox.rdm.rdms import RDMs
10+
11+
12+
def rdms_to_df(rdms: RDMs) -> DataFrame:
13+
"""Create DataFrame representation of the RDMs object
14+
15+
A column for:
16+
- dissimilarity
17+
- each rdm descriptor
18+
- two for each pattern descriptor, suffixed by _1 and _2 respectively
19+
20+
Multiple RDMs are stacked row-wise.
21+
See also the `RDMs.to_df()` method which calls this function
22+
23+
Args:
24+
rdms (RDMs): the object to convert
25+
26+
Returns:
27+
DataFrame: long-form pandas DataFrame with
28+
dissimilarities and descriptors.
29+
"""
30+
n_rdms, n_pairs = rdms.dissimilarities.shape
31+
cols = dict(dissimilarity=rdms.dissimilarities.ravel())
32+
for dname, dvals in rdms.rdm_descriptors.items():
33+
# rename the default index desc as that has special meaning in df
34+
cname = 'rdm_index' if dname == 'index' else dname
35+
cols[cname] = numpy.repeat(dvals, n_pairs)
36+
for dname, dvals in rdms.pattern_descriptors.items():
37+
ix = numpy.triu_indices(len(dvals), 1)
38+
# rename the default index desc as that has special meaning in df
39+
cname = 'pattern_index' if dname == 'index' else dname
40+
for p in (0, 1):
41+
cols[f'{cname}_{p+1}'] = numpy.tile(asarray(dvals)[ix[p]], n_rdms)
42+
return DataFrame(cols)

Diff for: src/rsatoolbox/rdm/rdms.py

+11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from copy import deepcopy
1111
from collections.abc import Iterable
1212
import numpy as np
13+
from rsatoolbox.io.pandas import rdms_to_df
1314
from rsatoolbox.rdm.combine import _mean
1415
from rsatoolbox.util.rdm_utils import batch_to_vectors
1516
from rsatoolbox.util.rdm_utils import batch_to_matrices
@@ -400,6 +401,16 @@ def to_dict(self):
400401
rdm_dict['dissimilarity_measure'] = self.dissimilarity_measure
401402
return rdm_dict
402403

404+
def to_df(self):
405+
"""Return a new long-form pandas DataFrame representing this RDM
406+
407+
See `rsatoolbox.io.pandas.rdms_to_df` for details
408+
409+
Returns:
410+
pandas.DataFrame: The DataFrame for this RDMs object
411+
"""
412+
return rdms_to_df(self)
413+
403414
def reorder(self, new_order):
404415
"""Reorder the patterns according to the index in new_order
405416

Diff for: src/rsatoolbox/rdm/transform.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
# -*- coding: utf-8 -*-
33
""" transforms, which can be applied to RDMs
44
"""
5-
5+
from __future__ import annotations
66
from copy import deepcopy
77
import numpy as np
88
from scipy.stats import rankdata
99
from .rdms import RDMs
1010

1111

12-
def rank_transform(rdms, method='average'):
12+
def rank_transform(rdms: RDMs, method='average'):
1313
""" applies a rank_transform and generates a new RDMs object
1414
This assigns a rank to each dissimilarity estimate in the RDM,
1515
deals with rank ties and saves ranks as new dissimilarity estimates.
@@ -30,9 +30,9 @@ def rank_transform(rdms, method='average'):
3030
dissimilarities = rdms.get_vectors()
3131
dissimilarities = np.array([rankdata(dissimilarities[i], method=method)
3232
for i in range(rdms.n_rdm)])
33-
measure = rdms.dissimilarity_measure
34-
if not measure[-7:] == '(ranks)':
35-
measure = measure + ' (ranks)'
33+
measure = rdms.dissimilarity_measure or ''
34+
if '(ranks)' not in measure:
35+
measure = (measure + ' (ranks)').strip()
3636
rdms_new = RDMs(dissimilarities,
3737
dissimilarity_measure=measure,
3838
descriptors=deepcopy(rdms.descriptors),
@@ -103,8 +103,9 @@ def transform(rdms, fun):
103103
"""
104104
dissimilarities = rdms.get_vectors()
105105
dissimilarities = fun(dissimilarities)
106+
meas = 'transformed ' + rdms.dissimilarity_measure
106107
rdms_new = RDMs(dissimilarities,
107-
dissimilarity_measure='transformed ' + rdms.dissimilarity_measure,
108+
dissimilarity_measure=meas,
108109
descriptors=deepcopy(rdms.descriptors),
109110
rdm_descriptors=deepcopy(rdms.rdm_descriptors),
110111
pattern_descriptors=deepcopy(rdms.pattern_descriptors))

Diff for: tests/test_rdm.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,12 @@ def test_rank_transform(self):
246246
self.assertEqual(rank_rdm.n_cond, rdms.n_cond)
247247
self.assertEqual(rank_rdm.dissimilarity_measure, 'Euclidean (ranks)')
248248

249+
def test_rank_transform_unknown_measure(self):
250+
from rsatoolbox.rdm import rank_transform
251+
rdms = rsr.RDMs(dissimilarities=np.zeros((8, 10)))
252+
rank_rdm = rank_transform(rdms)
253+
self.assertEqual(rank_rdm.dissimilarity_measure, '(ranks)')
254+
249255
def test_sqrt_transform(self):
250256
from rsatoolbox.rdm import sqrt_transform
251257
dis = np.zeros((8, 10))
@@ -463,17 +469,19 @@ def test_copy(self):
463469
)
464470
)
465471
copy = orig.copy()
466-
## We don't want a reference:
472+
# We don't want a reference:
467473
self.assertIsNot(copy, orig)
468474
self.assertIsNot(copy.dissimilarities, orig.dissimilarities)
469475
self.assertIsNot(
470476
copy.pattern_descriptors.get('order'),
471477
orig.pattern_descriptors.get('order')
472478
)
473-
## But check that attributes are equal
479+
# But check that attributes are equal
474480
assert_array_equal(copy.dissimilarities, orig.dissimilarities)
475-
self.assertEqual(copy.dissimilarity_measure,
476-
orig.dissimilarity_measure)
481+
self.assertEqual(
482+
copy.dissimilarity_measure,
483+
orig.dissimilarity_measure
484+
)
477485
self.assertEqual(copy.descriptors, orig.descriptors)
478486
self.assertEqual(copy.rdm_descriptors, orig.rdm_descriptors)
479487
assert_array_equal(

Diff for: tests/test_rdms_pandas.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
from unittest import TestCase
3+
from typing import TYPE_CHECKING, Union, List
4+
from numpy.testing import assert_array_equal
5+
import numpy
6+
from pandas import Series, DataFrame
7+
if TYPE_CHECKING:
8+
from numpy.typing import NDArray
9+
10+
11+
class RdmsToPandasTests(TestCase):
12+
13+
def assertValuesEqual(self,
14+
actual: Series,
15+
expected: Union[NDArray, List]):
16+
assert_array_equal(numpy.asarray(actual.values), expected)
17+
18+
def test_to_df(self):
19+
"""Convert an RDMs object to a pandas DataFrame
20+
21+
Default is long form; multiple rdms are stacked row-wise.
22+
"""
23+
from rsatoolbox.rdm.rdms import RDMs
24+
dissimilarities = numpy.random.rand(2, 6)
25+
rdms = RDMs(
26+
dissimilarities,
27+
rdm_descriptors=dict(xy=[c for c in 'xy']),
28+
pattern_descriptors=dict(abcd=numpy.asarray([c for c in 'abcd']))
29+
)
30+
df = rdms.to_df()
31+
self.assertIsInstance(df, DataFrame)
32+
self.assertEqual(len(df.columns), 7)
33+
self.assertValuesEqual(df.dissimilarity, dissimilarities.ravel())
34+
self.assertValuesEqual(df['rdm_index'], ([0]*6) + ([1]*6))
35+
self.assertValuesEqual(df['xy'], (['x']*6) + (['y']*6))
36+
self.assertValuesEqual(df['pattern_index_1'],
37+
([0]*3 + [1]*2 + [2]*1)*2)
38+
self.assertValuesEqual(df['pattern_index_2'], [1, 2, 3, 2, 3, 3]*2)
39+
self.assertValuesEqual(df['abcd_1'], [c for c in 'aaabbc']*2)
40+
self.assertValuesEqual(df['abcd_2'], [c for c in 'bcdcdd']*2)

0 commit comments

Comments
 (0)