|
| 1 | +from __future__ import annotations |
1 | 2 | from unittest import TestCase
|
| 3 | +from typing import TYPE_CHECKING, Union, List |
2 | 4 | from numpy.testing import assert_array_equal
|
3 | 5 | import numpy
|
4 |
| -import pandas |
| 6 | +from pandas import Series, DataFrame |
| 7 | +if TYPE_CHECKING: |
| 8 | + from numpy.typing import NDArray |
5 | 9 |
|
6 | 10 |
|
7 | 11 | class RdmsToPandasTests(TestCase):
|
8 | 12 |
|
| 13 | + def assertValuesEqual(self, |
| 14 | + actual: Series, |
| 15 | + expected: Union[NDArray, List]): |
| 16 | + assert_array_equal(numpy.asarray(actual.values), expected) |
| 17 | + |
9 | 18 | def test_to_df(self):
|
10 | 19 | """Convert an RDMs object to a pandas DataFrame
|
11 | 20 |
|
12 | 21 | Default is long form; multiple rdms are stacked row-wise.
|
13 | 22 | """
|
14 | 23 | from rsatoolbox.rdm.rdms import RDMs
|
15 |
| - dissimilarities = numpy.random.rand(2, 3) |
16 |
| - conds = [c for c in 'abc'] |
| 24 | + dissimilarities = numpy.random.rand(2, 6) |
17 | 25 | rdms = RDMs(
|
18 | 26 | dissimilarities,
|
19 |
| - rdm_descriptors=dict(xy=['x', 'y']), |
20 |
| - pattern_descriptors=dict(abc=numpy.asarray(conds)) |
| 27 | + rdm_descriptors=dict(xy=[c for c in 'xy']), |
| 28 | + pattern_descriptors=dict(abcd=numpy.asarray([c for c in 'abcd'])) |
21 | 29 | )
|
22 | 30 | df = rdms.to_df()
|
23 |
| - self.assertIsInstance(df, pandas.DataFrame) |
24 |
| - self.assertEqual(len(df.columns), 5) |
25 |
| - assert_array_equal(df.dissimilarity.values, dissimilarities.ravel()) |
26 |
| - assert_array_equal(df['rdm_index'].values, ([0]*3) + ([1]*3)) |
27 |
| - assert_array_equal(df['xy'].values, (['x']*3) + (['y']*3)) |
28 |
| - assert_array_equal(df['pattern_index'].values, list(range(3))*2) |
29 |
| - assert_array_equal(df['abc'].values, conds*2) |
| 31 | + self.assertIsInstance(df, DataFrame) |
| 32 | + self.assertEqual(len(df.columns), 7) |
| 33 | + self.assertValuesEqual(df.dissimilarity, dissimilarities.ravel()) |
| 34 | + self.assertValuesEqual(df['rdm_index'], ([0]*6) + ([1]*6)) |
| 35 | + self.assertValuesEqual(df['xy'], (['x']*6) + (['y']*6)) |
| 36 | + self.assertValuesEqual(df['pattern_index_1'], |
| 37 | + ([0]*3 + [1]*2 + [2]*1)*2) |
| 38 | + self.assertValuesEqual(df['pattern_index_2'], [1, 2, 3, 2, 3, 3]*2) |
| 39 | + self.assertValuesEqual(df['abcd_1'], [c for c in 'aaabbc']*2) |
| 40 | + self.assertValuesEqual(df['abcd_2'], [c for c in 'bcdcdd']*2) |
0 commit comments