Skip to content

Commit 12da07e

Browse files
author
Jasper van den Bosch
committed
rdm_to_df for pairs
1 parent 56c4716 commit 12da07e

File tree

2 files changed

+34
-22
lines changed

2 files changed

+34
-22
lines changed

src/rsatoolbox/io/pandas.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING
55
from pandas import DataFrame
66
import numpy
7+
from numpy import asarray
78
if TYPE_CHECKING:
89
from rsatoolbox.rdm.rdms import RDMs
910

@@ -13,8 +14,8 @@ def rdms_to_df(rdms: RDMs) -> DataFrame:
1314
1415
A column for:
1516
- dissimilarity
16-
- each pattern descriptor
1717
- each rdm descriptor
18+
- two for each pattern descriptor, suffixed by _1 and _2 respectively
1819
1920
Multiple RDMs are stacked row-wise.
2021
See also the `RDMs.to_df()` method which calls this function
@@ -26,16 +27,16 @@ def rdms_to_df(rdms: RDMs) -> DataFrame:
2627
DataFrame: long-form pandas DataFrame with
2728
dissimilarities and descriptors.
2829
"""
29-
n_rdms, n_conds = rdms.dissimilarities.shape
30+
n_rdms, n_pairs = rdms.dissimilarities.shape
3031
cols = dict(dissimilarity=rdms.dissimilarities.ravel())
3132
for dname, dvals in rdms.rdm_descriptors.items():
32-
# rename the default index descriptor as that has special meaning
33-
if dname == 'index':
34-
dname = 'rdm_index'
35-
cols[dname] = numpy.repeat(dvals, n_conds)
33+
# rename the default index desc as that has special meaning in df
34+
cname = 'rdm_index' if dname == 'index' else dname
35+
cols[cname] = numpy.repeat(dvals, n_pairs)
3636
for dname, dvals in rdms.pattern_descriptors.items():
37-
# rename the default index descriptor as that has special meaning
38-
if dname == 'index':
39-
dname = 'pattern_index'
40-
cols[dname] = numpy.tile(dvals, n_rdms)
37+
ix = numpy.triu_indices(len(dvals), 1)
38+
# rename the default index desc as that has special meaning in df
39+
cname = 'pattern_index' if dname == 'index' else dname
40+
for p in (0, 1):
41+
cols[f'{cname}_{p+1}'] = numpy.tile(asarray(dvals)[ix[p]], n_rdms)
4142
return DataFrame(cols)

tests/test_rdms_pandas.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,40 @@
1+
from __future__ import annotations
12
from unittest import TestCase
3+
from typing import TYPE_CHECKING, Union, List
24
from numpy.testing import assert_array_equal
35
import numpy
4-
import pandas
6+
from pandas import Series, DataFrame
7+
if TYPE_CHECKING:
8+
from numpy.typing import NDArray
59

610

711
class RdmsToPandasTests(TestCase):
812

13+
def assertValuesEqual(self,
14+
actual: Series,
15+
expected: Union[NDArray, List]):
16+
assert_array_equal(numpy.asarray(actual.values), expected)
17+
918
def test_to_df(self):
1019
"""Convert an RDMs object to a pandas DataFrame
1120
1221
Default is long form; multiple rdms are stacked row-wise.
1322
"""
1423
from rsatoolbox.rdm.rdms import RDMs
15-
dissimilarities = numpy.random.rand(2, 3)
16-
conds = [c for c in 'abc']
24+
dissimilarities = numpy.random.rand(2, 6)
1725
rdms = RDMs(
1826
dissimilarities,
19-
rdm_descriptors=dict(xy=['x', 'y']),
20-
pattern_descriptors=dict(abc=numpy.asarray(conds))
27+
rdm_descriptors=dict(xy=[c for c in 'xy']),
28+
pattern_descriptors=dict(abcd=numpy.asarray([c for c in 'abcd']))
2129
)
2230
df = rdms.to_df()
23-
self.assertIsInstance(df, pandas.DataFrame)
24-
self.assertEqual(len(df.columns), 5)
25-
assert_array_equal(df.dissimilarity.values, dissimilarities.ravel())
26-
assert_array_equal(df['rdm_index'].values, ([0]*3) + ([1]*3))
27-
assert_array_equal(df['xy'].values, (['x']*3) + (['y']*3))
28-
assert_array_equal(df['pattern_index'].values, list(range(3))*2)
29-
assert_array_equal(df['abc'].values, conds*2)
31+
self.assertIsInstance(df, DataFrame)
32+
self.assertEqual(len(df.columns), 7)
33+
self.assertValuesEqual(df.dissimilarity, dissimilarities.ravel())
34+
self.assertValuesEqual(df['rdm_index'], ([0]*6) + ([1]*6))
35+
self.assertValuesEqual(df['xy'], (['x']*6) + (['y']*6))
36+
self.assertValuesEqual(df['pattern_index_1'],
37+
([0]*3 + [1]*2 + [2]*1)*2)
38+
self.assertValuesEqual(df['pattern_index_2'], [1, 2, 3, 2, 3, 3]*2)
39+
self.assertValuesEqual(df['abcd_1'], [c for c in 'aaabbc']*2)
40+
self.assertValuesEqual(df['abcd_2'], [c for c in 'bcdcdd']*2)

0 commit comments

Comments
 (0)