Skip to content

Commit

Permalink
Merge pull request #285 from rsagroup/colormap-consensus
Browse files Browse the repository at this point in the history
Change default colormap
  • Loading branch information
JasperVanDenBosch authored Dec 15, 2022
2 parents 08feaf4 + 990860f commit f85bfd4
Show file tree
Hide file tree
Showing 14 changed files with 373 additions and 376 deletions.
103 changes: 54 additions & 49 deletions demos/demo_bootstrap.ipynb

Large diffs are not rendered by default.

79 changes: 42 additions & 37 deletions demos/demo_dissimilarities.ipynb

Large diffs are not rendered by default.

89 changes: 47 additions & 42 deletions demos/demo_flexible_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,27 @@
"output_type": "stream",
"text": [
"Predicting with theta = [1,0], should return the first rdm, which is:\n",
"[[0.20529161 0.18219921 0.14234779 0.15184744 0.21434174 0.1876438\n",
" 0.23245372 0.16816806 0.11574728 0.21842971]]\n",
"[[0.17948311 0.13203485 0.13632945 0.18298594 0.15166647 0.16234362\n",
" 0.14637612 0.12058838 0.15784657 0.12497874]]\n",
"The output of the model is:\n",
"[0.20529161 0.18219921 0.14234779 0.15184744 0.21434174 0.1876438\n",
" 0.23245372 0.16816806 0.11574728 0.21842971]\n",
"[0.17948311 0.13203485 0.13632945 0.18298594 0.15166647 0.16234362\n",
" 0.14637612 0.12058838 0.15784657 0.12497874]\n",
"Which is indeed identical\n",
"\n",
"Predicting with theta = [0,1], should return the second rdm, which is:\n",
"[[0.21021824 0.09930457 0.16356345 0.20092431 0.18402425 0.19685312\n",
" 0.13494642 0.13652705 0.1714637 0.12152749]]\n",
"[[0.14265908 0.17005546 0.18177815 0.1564901 0.13836761 0.17725487\n",
" 0.09220979 0.24478309 0.15514355 0.20351734]]\n",
"The output of the model is:\n",
"[0.21021824 0.09930457 0.16356345 0.20092431 0.18402425 0.19685312\n",
" 0.13494642 0.13652705 0.1714637 0.12152749]\n",
"[0.14265908 0.17005546 0.18177815 0.1564901 0.13836761 0.17725487\n",
" 0.09220979 0.24478309 0.15514355 0.20351734]\n",
"Which is indeed identical\n",
"\n",
"Predicting with theta = [1,1], should return the sum of the first two rdms, which is:\n",
"[[0.41550985 0.28150378 0.30591124 0.35277176 0.39836599 0.38449692\n",
" 0.36740013 0.3046951 0.28721099 0.3399572 ]]\n",
"[[0.32214218 0.30209031 0.3181076 0.33947604 0.29003408 0.33959849\n",
" 0.23858592 0.36537148 0.31299012 0.32849607]]\n",
"The output of the model is:\n",
"[0.41550985 0.28150378 0.30591124 0.35277176 0.39836599 0.38449692\n",
" 0.36740013 0.3046951 0.28721099 0.3399572 ]\n",
"[0.32214218 0.30209031 0.3181076 0.33947604 0.29003408 0.33959849\n",
" 0.23858592 0.36537148 0.31299012 0.32849607]\n",
"Which is indeed identical\n"
]
}
Expand Down Expand Up @@ -136,11 +136,11 @@
"squared euclidean\n",
"\n",
"dissimilarities[0] = \n",
"[[0. 0.20529161 0.18219921 0.14234779 0.15184744]\n",
" [0.20529161 0. 0.21434174 0.1876438 0.23245372]\n",
" [0.18219921 0.21434174 0. 0.16816806 0.11574728]\n",
" [0.14234779 0.1876438 0.16816806 0. 0.21842971]\n",
" [0.15184744 0.23245372 0.11574728 0.21842971 0. ]]\n",
"[[0. 0.17948311 0.13203485 0.13632945 0.18298594]\n",
" [0.17948311 0. 0.15166647 0.16234362 0.14637612]\n",
" [0.13203485 0.15166647 0. 0.12058838 0.15784657]\n",
" [0.13632945 0.16234362 0.12058838 0. 0.12497874]\n",
" [0.18298594 0.14637612 0.15784657 0.12497874 0. ]]\n",
"\n",
"descriptors: \n",
"\n",
Expand All @@ -161,11 +161,11 @@
"squared euclidean\n",
"\n",
"dissimilarities[0] = \n",
"[[0. 0.20529161 0.18219921 0.14234779 0.15184744]\n",
" [0.20529161 0. 0.21434174 0.1876438 0.23245372]\n",
" [0.18219921 0.21434174 0. 0.16816806 0.11574728]\n",
" [0.14234779 0.1876438 0.16816806 0. 0.21842971]\n",
" [0.15184744 0.23245372 0.11574728 0.21842971 0. ]]\n",
"[[0. 0.17948311 0.13203485 0.13632945 0.18298594]\n",
" [0.17948311 0. 0.15166647 0.16234362 0.14637612]\n",
" [0.13203485 0.15166647 0. 0.12058838 0.15784657]\n",
" [0.13632945 0.16234362 0.12058838 0. 0.12497874]\n",
" [0.18298594 0.14637612 0.15784657 0.12497874 0. ]]\n",
"\n",
"descriptors: \n",
"\n",
Expand Down Expand Up @@ -209,9 +209,9 @@
"output_type": "stream",
"text": [
"Theta based on optimization:\n",
"[0.54167376 0.8405888 ]\n",
"[0.94571213 0.32500549]\n",
"Theta based on fit_regress:\n",
"[0.54166958 0.8405915 ]\n"
"[0.94571213 0.32500548]\n"
]
}
],
Expand Down Expand Up @@ -245,10 +245,10 @@
{
"data": {
"text/plain": [
"(<Figure size 144x144 with 1 Axes>,\n",
" array([[<AxesSubplot:>]], dtype=object),\n",
"(<Figure size 200x200 with 1 Axes>,\n",
" array([[<AxesSubplot: >]], dtype=object),\n",
" defaultdict(dict,\n",
" {<AxesSubplot:>: {'image': <matplotlib.image.AxesImage at 0x7fba190f22e0>}}))"
" {<AxesSubplot: >: {'image': <matplotlib.image.AxesImage at 0x17f6279d0>}}))"
]
},
"execution_count": 6,
Expand All @@ -257,19 +257,19 @@
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHsAAAB7CAYAAABUx/9/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAACWklEQVR4nO3dsWoUURhH8f8Viy232C1i4ySNhWCVcp8ibCpLQWx9CGvxASSNgYAh3b7FTmsppE1CIGIpXgvb3anujeA5v/bCNwOHmerb2VJrjRie/Osb0OMxNoixQYwNYmwQY4M8nTpcLBZ1GIYOlx07zEzG8aD5zNnhrPnMJDm6v+8y99vDw12tdbnrbDL2MAzZbrcdbql0mJmU8rb5zKMPL5rPTJKL8/Muc19tNtf7znyNgxgbxNggxgYxNoixQYwNYmwQY4MYG8TYIMYGMTaIsUGMDWJsEGODGBvE2CDGBplcOPy7BdpjObDPjwlvXre/11+fmo9Mknx/d9Zn8Gaz98gnG8TYIMYGMTaIsUGMDWJsEGODGBvE2CDGBjE2iLFBjA1ibBBjgxgbxNggxgYxNoixQSa3S8fxoMsnIntsgSbJ8kv7rdXy8XPzmUmSn7/7zJ3gkw1ibBBjgxgbxNggxgYxNoixQYwNYmwQY4MYG8TYIMYGMTaIsUGMDWJsEGODGBtkcuFwdjjr8v/RvT4R2WM5sL5/03xmkpxeXXWZezlx5pMNYmwQY4MYG8TYIMYGMTaIsUGMDWJsEGODGBvE2CDGBjE2iLFBjA1ibBBjgxgbxNggpdb9n4B8OZ/Xi9Wq+UV/rNfNZybJqsMnItfP5s1nJsnXk5Muc0spY631eNeZTzaIsUGMDWJsEGODGBvE2CDGBjE2iLFBjA1ibBBjgxgbxNggxgYxNoixQYwNYmwQY4NMbpeWUm6TXD/e7aiB57XW5a6Dydj6v/gaBzE2iLFBjA1ibJA/T8ZLZyukztQAAAAASUVORK5CYII=\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAK4AAACuCAYAAACvDDbuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAADx0lEQVR4nO3csWpbZxjH4Vd2nKmS1tboQLyGQKF7IbfREsjWJSUXEELHQju00LHQudlC5y7pHeQWjIx2HQ1JiNHppNAOlY8b21/+7fPMH34/WT8OR8s7GYZhKAhz0PoC8G8Il0jCJZJwiSRcIgmXSMIlknCJdGvMoe12W6vVqqbTaU0mk+u+E/9TwzDUZrOp4+PjOjjY/0wdFe5qtaqu667kcnCR5XJZi8Vi75lR4U6n03d/cDabvf/NLunl6emNz9x59OWjJnO77m6TuVVVz55932Ru3/fVdd273vYZFe7u9WA2mzUJ96MRH+S6HB6O+hdduaOj203mVlWT7/ivxryO+nFGJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RLrXR7eXpaZMFdJ/duXPjM3dOTj5tMvfh0wdN5laNWzrXmicukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RhEukyTAMw0WH+r6v+Xxe9+59XoeHl9pMeiVarfqsqnr+/Mcmc796/G2TuVVVn5x83GTum9ev6rsnX9d6va7ZbLb3rCcukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RLrV6sevu1tHR7eu6yz96+PTBjc/cabU18eefnjSZW1V1//4XTeaen78dfdYTl0jCJZJwiSRcIgmXSMIlknCJJFwiCZdIwiWScIkkXCIJl0jCJZJwiSRcIgmXSMIlknCJJFwiCZdIwiWScIkkXCIJl0jCJZJwiSRcIgmXSMIl0mQYhuGiQ33f13w+r/V6XbPZ7Cbu9TeTyeTGZ+5888MvTeb+8dvvTeZWVb148WuTuZfpzBOXSMIlknCJJFwiCZdIwiWScIkkXCIJl0jCJZJwiSRcIgmXSMIlknCJJFwiCZdIwiWScIkkXCIJl0jCJZJwiSRcIgmXSMIlknCJJFwiCZdIt8Yc2i107Pv+Wi/zIXrz+lWTuefnb5vMrWr3Pe/mjlggOm7N6NnZWXVd9/43gxGWy2UtFou9Z0aFu91ua7Va1XQ6bbqrlv+2YRhqs9nU8fFxHRzsf4sdFS58aPw4I5JwiSRcIgmXSMIlknCJJFwi/Qmeoo0IvbUyEAAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
"<Figure size 200x200 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHsAAAB7CAYAAABUx/9/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAACWklEQVR4nO3dsWoUURhH8f8Viy232C1i4ySNhWCVcp8ibCpLQWx9CGvxASSNgYAh3b7FTmsppE1CIGIpXgvb3anujeA5v/bCNwOHmerb2VJrjRie/Osb0OMxNoixQYwNYmwQY4M8nTpcLBZ1GIYOlx07zEzG8aD5zNnhrPnMJDm6v+8y99vDw12tdbnrbDL2MAzZbrcdbql0mJmU8rb5zKMPL5rPTJKL8/Muc19tNtf7znyNgxgbxNggxgYxNoixQYwNYmwQY4MYG8TYIMYGMTaIsUGMDWJsEGODGBvE2CDGBplcOPy7BdpjObDPjwlvXre/11+fmo9Mknx/d9Zn8Gaz98gnG8TYIMYGMTaIsUGMDWJsEGODGBvE2CDGBjE2iLFBjA1ibBBjgxgbxNggxgYxNoixQSa3S8fxoMsnIntsgSbJ8kv7rdXy8XPzmUmSn7/7zJ3gkw1ibBBjgxgbxNggxgYxNoixQYwNYmwQY4MYG8TYIMYGMTaIsUGMDWJsEGODGBtkcuFwdjjr8v/RvT4R2WM5sL5/03xmkpxeXXWZezlx5pMNYmwQY4MYG8TYIMYGMTaIsUGMDWJsEGODGBvE2CDGBjE2iLFBjA1ibBBjgxgbxNggpdb9n4B8OZ/Xi9Wq+UV/rNfNZybJqsMnItfP5s1nJsnXk5Muc0spY631eNeZTzaIsUGMDWJsEGODGBvE2CDGBjE2iLFBjA1ibBBjgxgbxNggxgYxNoixQYwNYmwQY4NMbpeWUm6TXD/e7aiB57XW5a6Dydj6v/gaBzE2iLFBjA1ibJA/T8ZLZyukztQAAAAASUVORK5CYII=\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAK4AAACuCAYAAACvDDbuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAADx0lEQVR4nO3csWpbZxjH4Vd2nKmS1tboQLyGQKF7IbfREsjWJSUXEELHQju00LHQudlC5y7pHeQWjIx2HQ1JiNHppNAOlY8b21/+7fPMH34/WT8OR8s7GYZhKAhz0PoC8G8Il0jCJZJwiSRcIgmXSMIlknCJdGvMoe12W6vVqqbTaU0mk+u+E/9TwzDUZrOp4+PjOjjY/0wdFe5qtaqu667kcnCR5XJZi8Vi75lR4U6n03d/cDabvf/NLunl6emNz9x59OWjJnO77m6TuVVVz55932Ru3/fVdd273vYZFe7u9WA2mzUJ96MRH+S6HB6O+hdduaOj203mVlWT7/ivxryO+nFGJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RLrXR7eXpaZMFdJ/duXPjM3dOTj5tMvfh0wdN5laNWzrXmicukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RhEukyTAMw0WH+r6v+Xxe9+59XoeHl9pMeiVarfqsqnr+/Mcmc796/G2TuVVVn5x83GTum9ev6rsnX9d6va7ZbLb3rCcukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RhEsk4RJJuEQSLpGESyThEkm4RBIukYRLJOESSbhEEi6RLrV6sevu1tHR7eu6yz96+PTBjc/cabU18eefnjSZW1V1//4XTeaen78dfdYTl0jCJZJwiSRcIgmXSMIlknCJJFwiCZdIwiWScIkkXCIJl0jCJZJwiSRcIgmXSMIlknCJJFwiCZdIwiWScIkkXCIJl0jCJZJwiSRcIgmXSMIl0mQYhuGiQ33f13w+r/V6XbPZ7Cbu9TeTyeTGZ+5888MvTeb+8dvvTeZWVb148WuTuZfpzBOXSMIlknCJJFwiCZdIwiWScIkkXCIJl0jCJZJwiSRcIgmXSMIlknCJJFwiCZdIwiWScIkkXCIJl0jCJZJwiSRcIgmXSMIlknCJJFwiCZdIt8Yc2i107Pv+Wi/zIXrz+lWTuefnb5vMrWr3Pe/mjlggOm7N6NnZWXVd9/43gxGWy2UtFou9Z0aFu91ua7Va1XQ6bbqrlv+2YRhqs9nU8fFxHRzsf4sdFS58aPw4I5JwiSRcIgmXSMIlknCJJFwi/Qmeoo0IvbUyEAAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
"<Figure size 200x200 with 1 Axes>"
]
},
"metadata": {},
Expand Down Expand Up @@ -301,9 +301,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[0.54171328 0.84056334]\n",
"[0.94571213 0.32500549]\n",
"the used fitting function was:\n",
"<function fit_optimize at 0x7fba18bdb9d0>\n"
"<function fit_optimize at 0x153191000>\n"
]
}
],
Expand Down Expand Up @@ -338,9 +338,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.04123096 0.99914964]\n",
"[-0.04123086 0.99914965]\n",
"[-0.04122998 0.99914968]\n"
"[ 0.996767 -0.08034645]\n",
"[ 0.996767 -0.08034645]\n",
"[ 0.996767 -0.08034647]\n"
]
}
],
Expand Down Expand Up @@ -378,13 +378,13 @@
"output_type": "stream",
"text": [
"The average correlation for the correlation parameters is:\n",
"0.20919570220266936\n",
"0.0890392385172882\n",
"The average correlation for the cosine similarity parameters is:\n",
"0.1648685520549055\n",
"0.06085602619789239\n",
"The average cosine similarity for the correlation parameters is:\n",
"0.9609090694876308\n",
"0.9649654254976477\n",
"The average cosine similarity for the cosine similarity parameters is:\n",
"0.9712386973494105\n"
"0.9721299166013238\n"
]
}
],
Expand Down Expand Up @@ -432,7 +432,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "env",
"language": "python",
"name": "python3"
},
Expand All @@ -446,7 +446,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.10.4"
},
"vscode": {
"interpreter": {
"hash": "af6f0c1be22da210ce14b764d3d407b4e31df46360687c396ac7d1fbf0a9a76f"
}
}
},
"nbformat": 4,
Expand Down
17 changes: 11 additions & 6 deletions demos/demo_rdm_comparison_scatterplot.ipynb

Large diffs are not rendered by default.

226 changes: 95 additions & 131 deletions demos/demo_rdm_visualisation_92images.ipynb

Large diffs are not rendered by default.

155 changes: 78 additions & 77 deletions demos/demo_temporal.ipynb

Large diffs are not rendered by default.

Binary file modified demos/temp_rdm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 4 additions & 1 deletion src/rsatoolbox/util/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,8 @@ class Weighted_MDS(BaseEstimator):

def __init__(self, n_components=2, *, metric=True, n_init=4,
max_iter=300, verbose=0, eps=1e-3, n_jobs=None,
random_state=None, dissimilarity="euclidean"):
random_state=None, dissimilarity="euclidean",
normalized_stress='auto'):
self.n_components = n_components
self.dissimilarity = dissimilarity
self.metric = metric
Expand All @@ -436,6 +437,8 @@ def __init__(self, n_components=2, *, metric=True, n_init=4,
self.embedding_ = None
self.stress_ = None
self.n_iter_ = None
# not in use, declared for consistency with sklearn:
self.normalized_stress = normalized_stress

@property
def _pairwise(self):
Expand Down
14 changes: 6 additions & 8 deletions src/rsatoolbox/vis/colors.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
Definition of rsatoolbox's colors
Classic colormap ported from matlab rsatoolbox
@author: iancharest
"""

from __future__ import annotations
import numpy as np
from skimage.color import rgb2hsv, hsv2rgb
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy.interpolate import interp1d


def color_scale(n_cols, anchor_cols=None, monitor=False):
def color_scale(n_cols: int, anchor_cols=None, monitor=False):
""" linearly interpolates between a set of given
anchor colours to give n_cols and displays them
if monitor is set
Expand Down Expand Up @@ -55,7 +53,7 @@ def color_scale(n_cols, anchor_cols=None, monitor=False):
return cols


def rdm_colormap(n_cols=256, monitor=None):
def rdm_colormap_classic(n_cols: int = 256, monitor: bool = False):
"""this function provides a convenient colormap for visualizing
dissimilarity matrices. it goes from blue to yellow and has grey for
intermediate values.
Expand All @@ -73,8 +71,8 @@ def rdm_colormap(n_cols=256, monitor=None):
import numpy as np
import matplotlib.pyplot as plt
from rsatoolbox.vis.colors import rdm_colormap
plt.imshow(np.random.rand(10,10),cmap=rdm_colormap())
from rsatoolbox.vis.colors import rdm_colormap_classic
plt.imshow(np.random.rand(10,10),cmap=rdm_colormap_classic())
plt.colorbar()
plt.show()
Expand Down
8 changes: 6 additions & 2 deletions src/rsatoolbox/vis/icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import os
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib
from matplotlib.offsetbox import OffsetImage, AnnotationBbox, DrawingArea
import numpy as np
import PIL
Expand All @@ -15,6 +15,10 @@
from PIL import UnidentifiedImageError
from rsatoolbox.rdm import RDMs
from rsatoolbox.util.pooling import pool_rdm
if hasattr(matplotlib.colormaps, 'get_cmap'):
mpl_get_cmap = matplotlib.colormaps.get_cmap
else:
mpl_get_cmap = matplotlib.cm.get_cmap # drop:py37


class Icon:
Expand Down Expand Up @@ -243,7 +247,7 @@ def recompute_final_image(self):
else:
im = self._image
if self.cmap is not None:
im = cm.get_cmap(self.cmap)(im)
im = mpl_get_cmap(self.cmap)(im)
im = PIL.Image.fromarray((im * 255).astype(np.uint8))
else: # we hope it is a PIL image or equivalent
im = self._image
Expand Down
27 changes: 16 additions & 11 deletions src/rsatoolbox/vis/rdm_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,27 @@

from __future__ import annotations
import collections
from typing import TYPE_CHECKING, Union, Tuple
from typing import TYPE_CHECKING, Union, Tuple, Optional
import pkg_resources
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import rsatoolbox.rdm
from rsatoolbox import vis
from rsatoolbox.vis.colors import rdm_colormap
from rsatoolbox.vis.colors import rdm_colormap_classic
if TYPE_CHECKING:
import numpy.typing as npt
import pathlib
from matplotlib.axes._axes import Axes
from matplotlib.colors import Colormap

RDM_STYLE = pkg_resources.resource_filename('rsatoolbox.vis', 'rdm.mplstyle')


def show_rdm(
rdm: rsatoolbox.rdm.RDMs,
pattern_descriptor: str = None,
cmap: Union[str, matplotlib.colors.Colormap] = None,
cmap: Union[str, Colormap] = 'bone',
rdm_descriptor: str = None,
n_column: int = None,
n_row: int = None,
Expand All @@ -48,8 +49,10 @@ def show_rdm(
rdm (rsatoolbox.rdm.RDMs): RDMs object to be plotted.
pattern_descriptor (str): Key into rdm.pattern_descriptors to use for axis
labels.
cmap (Union[str, matplotlib.colors.Colormap]): colormap to be used (by
plt.imshow internally). By default we use rdm_colormap.
cmap (str or Colormap): Colormap to be used.
Either the name of a Matplotlib built-in colormap, a Matplotlib
Colormap compatible object, or 'classic' for the matlab toolbox
colormap. Defaults to 'bone'.
rdm_descriptor (str): Key for rdm_descriptor to use as panel title, or
str for direct labeling.
n_column (int): Number of columns in subplot arrangement.
Expand Down Expand Up @@ -255,8 +258,8 @@ def _rdm_colorbar(

def show_rdm_panel(
rdm: rsatoolbox.rdm.RDMs,
ax: Axes = None,
cmap: Union[str, matplotlib.colors.Colormap] = None,
ax: Optional[Axes] = None,
cmap: Union[str, Colormap] = 'bone',
nanmask: npt.ArrayLike = None,
rdm_descriptor: str = None,
gridlines: npt.ArrayLike = None,
Expand All @@ -268,8 +271,10 @@ def show_rdm_panel(
Args:
rdm (rsatoolbox.rdm.RDMs): RDMs object to be plotted (n_rdm must be 1).
ax (matplotlib.axes._axes.Axes): Matplotlib axis handle. plt.gca() by default.
cmap (Union[str, matplotlib.colors.Colormap]): colormap to be used (by
plt.imshow internally). By default we use rdm_colormap.
cmap (str or Colormap): Colormap to be used.
Either the name of a Matplotlib built-in colormap, a Matplotlib
Colormap compatible object, or 'classic' for the matlab toolbox
colormap. Defaults to 'bone'.
nanmask (npt.ArrayLike): boolean mask defining RDM elements to suppress
(by default, the diagonals).
rdm_descriptor (str): Key for rdm_descriptor to use as panel title, or
Expand All @@ -287,8 +292,8 @@ def show_rdm_panel(
raise ValueError("expected single rdm - use show_rdm for multi-panel figures")
if ax is None:
ax = plt.gca()
if cmap is None:
cmap = rdm_colormap()
if cmap == 'classic':
cmap = rdm_colormap_classic()
if nanmask is None:
nanmask = np.eye(rdm.n_cond, dtype=bool)
if not np.any(gridlines):
Expand Down
7 changes: 6 additions & 1 deletion src/rsatoolbox/vis/scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@ def show_2d(
"""
if method == 'MDS':
MDS = sklearn.manifold.MDS if weights is None else Weighted_MDS
embedding = MDS(n_components=2, random_state=seed, dissimilarity='precomputed')
embedding = MDS(
n_components=2,
random_state=seed,
dissimilarity='precomputed',
# normalized_stress='auto' # drop:py37
)
elif method == 't-SNE':
embedding = sklearn.manifold.TSNE(n_components=2)
elif method == 'Isomap':
Expand Down
Loading

0 comments on commit f85bfd4

Please sign in to comment.