diff --git a/CHANGELOG.md b/CHANGELOG.md index 39d33089967..564cc774fa1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,26 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). +## [2.0.0] + +### Changed +- `plotly.exceptions.PlotlyRequestException` is *always* raised for network +failures. Previously either a `PlotlyError`, `PlotlyRequestException`, or a +`requests.exceptions.ReqestException` could be raised. In particular, scripts +which depend on `try-except` blocks containing network requests should be +revisited. +- `plotly.py:sign_in` now validates to the plotly server specified in your + config. If it cannot make a successful request, it raises a `PlotlyError`. +- `plotly.figure_factory` will raise an `ImportError` if `numpy` is not + installed. + +### Deprecated +- `plotly.tools.FigureFactory`. Use `plotly.figure_factory.*`. +- (optional imports) `plotly.tools._*_imported` It was private anyhow, but now +it's gone. (e.g., `_numpy_imported`) +- (plotly v2 helper) `plotly.py._api_v2` It was private anyhow, but now it's +gone. + ## [1.13.0] - 2016-01-17 ### Added - Python 3.5 has been added as a tested environment for this package. diff --git a/circle.yml b/circle.yml index 62a7bdb505a..29fb8d4fa11 100644 --- a/circle.yml +++ b/circle.yml @@ -52,5 +52,9 @@ test: - sudo chmod -R 444 ${PLOTLY_CONFIG_DIR} && python -c "import plotly" # test that giving back write permissions works again - # this also has to pass the test suite that follows - sudo chmod -R 777 ${PLOTLY_CONFIG_DIR} && python -c "import plotly" + + # test that figure_factory cannot be imported with only core requirements. + # since optional requirements is part of the test suite, we don't need to + # worry about testing that it *can* be imported in this case. + - $(! python -c "import plotly.figure_factory") diff --git a/optional-requirements.txt b/optional-requirements.txt index 9dcc04023b1..ada6fb8598d 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -12,8 +12,9 @@ numpy # matplotlib==1.3.1 ## testing dependencies ## -nose -coverage +coverage==4.3.1 +mock==2.0.0 +nose==1.3.3 ## ipython ## ipython diff --git a/plotly/api/__init__.py b/plotly/api/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/plotly/api/utils.py b/plotly/api/utils.py new file mode 100644 index 00000000000..d9d1d21f504 --- /dev/null +++ b/plotly/api/utils.py @@ -0,0 +1,41 @@ +from base64 import b64encode + +from requests.compat import builtin_str, is_py2 + + +def _to_native_string(string, encoding): + if isinstance(string, builtin_str): + return string + if is_py2: + return string.encode(encoding) + return string.decode(encoding) + + +def to_native_utf8_string(string): + return _to_native_string(string, 'utf-8') + + +def to_native_ascii_string(string): + return _to_native_string(string, 'ascii') + + +def basic_auth(username, password): + """ + Creates the basic auth value to be used in an authorization header. + + This is mostly copied from the requests library. + + :param (str) username: A Plotly username. + :param (str) password: The password for the given Plotly username. + :returns: (str) An 'authorization' header for use in a request header. + + """ + if isinstance(username, str): + username = username.encode('latin1') + + if isinstance(password, str): + password = password.encode('latin1') + + return 'Basic ' + to_native_ascii_string( + b64encode(b':'.join((username, password))).strip() + ) diff --git a/plotly/api/v1/__init__.py b/plotly/api/v1/__init__.py new file mode 100644 index 00000000000..a43ff61f4c8 --- /dev/null +++ b/plotly/api/v1/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import + +from plotly.api.v1.clientresp import clientresp diff --git a/plotly/api/v1/clientresp.py b/plotly/api/v1/clientresp.py new file mode 100644 index 00000000000..c3af66c6b1c --- /dev/null +++ b/plotly/api/v1/clientresp.py @@ -0,0 +1,44 @@ +"""Interface to deprecated /clientresp API. Subject to deletion.""" +from __future__ import absolute_import + +import warnings + +from requests.compat import json as _json + +from plotly import config, utils, version +from plotly.api.v1.utils import request + + +def clientresp(data, **kwargs): + """ + Deprecated endpoint, still used because it can parse data out of a plot. + + When we get around to forcing users to create grids and then create plots, + we can finally get rid of this. + + :param (list) data: The data array from a figure. + + """ + creds = config.get_credentials() + cfg = config.get_config() + + dumps_kwargs = {'sort_keys': True, 'cls': utils.PlotlyJSONEncoder} + + payload = { + 'platform': 'python', 'version': version.__version__, + 'args': _json.dumps(data, **dumps_kwargs), + 'un': creds['username'], 'key': creds['api_key'], 'origin': 'plot', + 'kwargs': _json.dumps(kwargs, **dumps_kwargs) + } + + url = '{plotly_domain}/clientresp'.format(**cfg) + response = request('post', url, data=payload) + + # Old functionality, just keeping it around. + parsed_content = response.json() + if parsed_content.get('warning'): + warnings.warn(parsed_content['warning']) + if parsed_content.get('message'): + print(parsed_content['message']) + + return response diff --git a/plotly/api/v1/utils.py b/plotly/api/v1/utils.py new file mode 100644 index 00000000000..abfdf745c3e --- /dev/null +++ b/plotly/api/v1/utils.py @@ -0,0 +1,87 @@ +from __future__ import absolute_import + +import requests +from requests.exceptions import RequestException + +from plotly import config, exceptions +from plotly.api.utils import basic_auth + + +def validate_response(response): + """ + Raise a helpful PlotlyRequestError for failed requests. + + :param (requests.Response) response: A Response object from an api request. + :raises: (PlotlyRequestError) If the request failed for any reason. + :returns: (None) + + """ + content = response.content + status_code = response.status_code + try: + parsed_content = response.json() + except ValueError: + message = content if content else 'No Content' + raise exceptions.PlotlyRequestError(message, status_code, content) + + message = '' + if isinstance(parsed_content, dict): + error = parsed_content.get('error') + if error: + message = error + else: + if response.ok: + return + if not message: + message = content if content else 'No Content' + + raise exceptions.PlotlyRequestError(message, status_code, content) + + +def get_headers(): + """ + Using session credentials/config, get headers for a v1 API request. + + Users may have their own proxy layer and so we free up the `authorization` + header for this purpose (instead adding the user authorization in a new + `plotly-authorization` header). See pull #239. + + :returns: (dict) Headers to add to a requests.request call. + + """ + headers = {} + creds = config.get_credentials() + proxy_auth = basic_auth(creds['proxy_username'], creds['proxy_password']) + + if config.get_config()['plotly_proxy_authorization']: + headers['authorization'] = proxy_auth + + return headers + + +def request(method, url, **kwargs): + """ + Central place to make any v1 api request. + + :param (str) method: The request method ('get', 'put', 'delete', ...). + :param (str) url: The full api url to make the request to. + :param kwargs: These are passed along to requests. + :return: (requests.Response) The response directly from requests. + + """ + if kwargs.get('json', None) is not None: + # See plotly.api.v2.utils.request for examples on how to do this. + raise exceptions.PlotlyError('V1 API does not handle arbitrary json.') + kwargs['headers'] = dict(kwargs.get('headers', {}), **get_headers()) + kwargs['verify'] = config.get_config()['plotly_ssl_verification'] + try: + response = requests.request(method, url, **kwargs) + except RequestException as e: + # The message can be an exception. E.g., MaxRetryError. + message = str(getattr(e, 'message', 'No message')) + response = getattr(e, 'response', None) + status_code = response.status_code if response else None + content = response.content if response else 'No content' + raise exceptions.PlotlyRequestError(message, status_code, content) + validate_response(response) + return response diff --git a/plotly/api/v2/__init__.py b/plotly/api/v2/__init__.py new file mode 100644 index 00000000000..8424927d1c6 --- /dev/null +++ b/plotly/api/v2/__init__.py @@ -0,0 +1,4 @@ +from __future__ import absolute_import + +from plotly.api.v2 import (files, folders, grids, images, plot_schema, plots, + users) diff --git a/plotly/api/v2/files.py b/plotly/api/v2/files.py new file mode 100644 index 00000000000..650ab48fc85 --- /dev/null +++ b/plotly/api/v2/files.py @@ -0,0 +1,85 @@ +"""Interface to Plotly's /v2/files endpoints.""" +from __future__ import absolute_import + +from plotly.api.v2.utils import build_url, make_params, request + +RESOURCE = 'files' + + +def retrieve(fid, share_key=None): + """ + Retrieve a general file from Plotly. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) share_key: The secret key granting 'read' access if private. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid) + params = make_params(share_key=share_key) + return request('get', url, params=params) + + +def update(fid, body): + """ + Update a general file from Plotly. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid) + return request('put', url, json=body) + + +def trash(fid): + """ + Soft-delete a general file from Plotly. (Can be undone with 'restore'). + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='trash') + return request('post', url) + + +def restore(fid): + """ + Restore a trashed, general file from Plotly. See 'trash'. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='restore') + return request('post', url) + + +def permanent_delete(fid): + """ + Permanently delete a trashed, general file from Plotly. See 'trash'. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='permanent_delete') + return request('delete', url) + + +def lookup(path, parent=None, user=None, exists=None): + """ + Retrieve a general file from Plotly without needing a fid. + + :param (str) path: The '/'-delimited path specifying the file location. + :param (int) parent: Parent id, an integer, which the path is relative to. + :param (str) user: The username to target files for. Defaults to requestor. + :param (bool) exists: If True, don't return the full file, just a flag. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, route='lookup') + params = make_params(path=path, parent=parent, user=user, exists=exists) + return request('get', url, params=params) diff --git a/plotly/api/v2/folders.py b/plotly/api/v2/folders.py new file mode 100644 index 00000000000..2dcf84670e7 --- /dev/null +++ b/plotly/api/v2/folders.py @@ -0,0 +1,103 @@ +"""Interface to Plotly's /v2/folders endpoints.""" +from __future__ import absolute_import + +from plotly.api.v2.utils import build_url, make_params, request + +RESOURCE = 'folders' + + +def create(body): + """ + Create a new folder. + + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE) + return request('post', url, json=body) + + +def retrieve(fid, share_key=None): + """ + Retrieve a folder from Plotly. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) share_key: The secret key granting 'read' access if private. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid) + params = make_params(share_key=share_key) + return request('get', url, params=params) + + +def update(fid, body): + """ + Update a folder from Plotly. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid) + return request('put', url, json=body) + + +def trash(fid): + """ + Soft-delete a folder from Plotly. (Can be undone with 'restore'). + + This action is recursively done on files inside the folder. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='trash') + return request('post', url) + + +def restore(fid): + """ + Restore a trashed folder from Plotly. See 'trash'. + + This action is recursively done on files inside the folder. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='restore') + return request('post', url) + + +def permanent_delete(fid): + """ + Permanently delete a trashed folder file from Plotly. See 'trash'. + + This action is recursively done on files inside the folder. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='permanent_delete') + return request('delete', url) + + +def lookup(path, parent=None, user=None, exists=None): + """ + Retrieve a folder file from Plotly without needing a fid. + + :param (str) path: The '/'-delimited path specifying the file location. + :param (int) parent: Parent id, an integer, which the path is relative to. + :param (str) user: The username to target files for. Defaults to requestor. + :param (bool) exists: If True, don't return the full file, just a flag. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, route='lookup') + params = make_params(path=path, parent=parent, user=user, exists=exists) + return request('get', url, params=params) diff --git a/plotly/api/v2/grids.py b/plotly/api/v2/grids.py new file mode 100644 index 00000000000..144ec3bd23f --- /dev/null +++ b/plotly/api/v2/grids.py @@ -0,0 +1,180 @@ +"""Interface to Plotly's /v2/grids endpoints.""" +from __future__ import absolute_import + +from plotly.api.v2.utils import build_url, make_params, request + +RESOURCE = 'grids' + + +def create(body): + """ + Create a new grid. + + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE) + return request('post', url, json=body) + + +def retrieve(fid, share_key=None): + """ + Retrieve a grid from Plotly. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) share_key: The secret key granting 'read' access if private. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid) + params = make_params(share_key=share_key) + return request('get', url, params=params) + + +def content(fid, share_key=None): + """ + Retrieve full content for the grid (normal retrieve only yields preview) + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) share_key: The secret key granting 'read' access if private. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='content') + params = make_params(share_key=share_key) + return request('get', url, params=params) + + +def update(fid, body): + """ + Update a grid from Plotly. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid) + return request('put', url, json=body) + + +def trash(fid): + """ + Soft-delete a grid from Plotly. (Can be undone with 'restore'). + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='trash') + return request('post', url) + + +def restore(fid): + """ + Restore a trashed grid from Plotly. See 'trash'. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='restore') + return request('post', url) + + +def permanent_delete(fid): + """ + Permanently delete a trashed grid file from Plotly. See 'trash'. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='permanent_delete') + return request('delete', url) + + +def lookup(path, parent=None, user=None, exists=None): + """ + Retrieve a grid file from Plotly without needing a fid. + + :param (str) path: The '/'-delimited path specifying the file location. + :param (int) parent: Parent id, an integer, which the path is relative to. + :param (str) user: The username to target files for. Defaults to requestor. + :param (bool) exists: If True, don't return the full file, just a flag. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, route='lookup') + params = make_params(path=path, parent=parent, user=user, exists=exists) + return request('get', url, params=params) + + +def col_create(fid, body): + """ + Create a new column (or columns) inside a grid. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='col') + return request('post', url, json=body) + + +def col_retrieve(fid, uid): + """ + Retrieve a column (or columns) from a grid. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) uid: A ','-concatenated string of column uids in the grid. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='col') + params = make_params(uid=uid) + return request('get', url, params=params) + + +def col_update(fid, uid, body): + """ + Update a column (or columns) from a grid. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) uid: A ','-concatenated string of column uids in the grid. + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='col') + params = make_params(uid=uid) + return request('put', url, json=body, params=params) + + +def col_delete(fid, uid): + """ + Permanently delete a column (or columns) from a grid. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) uid: A ','-concatenated string of column uids in the grid. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='col') + params = make_params(uid=uid) + return request('delete', url, params=params) + + +def row(fid, body): + """ + Append rows to a grid. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='row') + return request('post', url, json=body) diff --git a/plotly/api/v2/images.py b/plotly/api/v2/images.py new file mode 100644 index 00000000000..4c9d1816081 --- /dev/null +++ b/plotly/api/v2/images.py @@ -0,0 +1,18 @@ +"""Interface to Plotly's /v2/images endpoints.""" +from __future__ import absolute_import + +from plotly.api.v2.utils import build_url, request + +RESOURCE = 'images' + + +def create(body): + """ + Generate an image (which does not get saved on Plotly). + + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE) + return request('post', url, json=body) diff --git a/plotly/api/v2/plot_schema.py b/plotly/api/v2/plot_schema.py new file mode 100644 index 00000000000..4edbc0a707b --- /dev/null +++ b/plotly/api/v2/plot_schema.py @@ -0,0 +1,19 @@ +"""Interface to Plotly's /v2/plot-schema endpoints.""" +from __future__ import absolute_import + +from plotly.api.v2.utils import build_url, make_params, request + +RESOURCE = 'plot-schema' + + +def retrieve(sha1, **kwargs): + """ + Retrieve the most up-to-date copy of the plot-schema wrt the given hash. + + :param (str) sha1: The last-known hash of the plot-schema (or ''). + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE) + params = make_params(sha1=sha1) + return request('get', url, params=params, **kwargs) diff --git a/plotly/api/v2/plots.py b/plotly/api/v2/plots.py new file mode 100644 index 00000000000..da9f2d9e395 --- /dev/null +++ b/plotly/api/v2/plots.py @@ -0,0 +1,119 @@ +"""Interface to Plotly's /v2/plots endpoints.""" +from __future__ import absolute_import + +from plotly.api.v2.utils import build_url, make_params, request + +RESOURCE = 'plots' + + +def create(body): + """ + Create a new plot. + + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE) + return request('post', url, json=body) + + +def retrieve(fid, share_key=None): + """ + Retrieve a plot from Plotly. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) share_key: The secret key granting 'read' access if private. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid) + params = make_params(share_key=share_key) + return request('get', url, params=params) + + +def content(fid, share_key=None, inline_data=None, map_data=None): + """ + Retrieve the *figure* for a Plotly plot file. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (str) share_key: The secret key granting 'read' access if private. + :param (bool) inline_data: If True, include the data arrays with the plot. + :param (str) map_data: Currently only accepts 'unreadable' to return a + mapping of grid-fid: grid. This is useful if you + want to maintain structure between the plot and + referenced grids when you have READ access to the + plot, but you don't have READ access to the + underlying grids. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='content') + params = make_params(share_key=share_key, inline_data=inline_data, + map_data=map_data) + return request('get', url, params=params) + + +def update(fid, body): + """ + Update a plot from Plotly. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :param (dict) body: A mapping of body param names to values. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid) + return request('put', url, json=body) + + +def trash(fid): + """ + Soft-delete a plot from Plotly. (Can be undone with 'restore'). + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='trash') + return request('post', url) + + +def restore(fid): + """ + Restore a trashed plot from Plotly. See 'trash'. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='restore') + return request('post', url) + + +def permanent_delete(fid, params=None): + """ + Permanently delete a trashed plot file from Plotly. See 'trash'. + + :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, id=fid, route='permanent_delete') + return request('delete', url, params=params) + + +def lookup(path, parent=None, user=None, exists=None): + """ + Retrieve a plot file from Plotly without needing a fid. + + :param (str) path: The '/'-delimited path specifying the file location. + :param (int) parent: Parent id, an integer, which the path is relative to. + :param (str) user: The username to target files for. Defaults to requestor. + :param (bool) exists: If True, don't return the full file, just a flag. + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, route='lookup') + params = make_params(path=path, parent=parent, user=user, exists=exists) + return request('get', url, params=params) diff --git a/plotly/api/v2/users.py b/plotly/api/v2/users.py new file mode 100644 index 00000000000..cdfaf51c488 --- /dev/null +++ b/plotly/api/v2/users.py @@ -0,0 +1,17 @@ +"""Interface to Plotly's /v2/files endpoints.""" +from __future__ import absolute_import + +from plotly.api.v2.utils import build_url, request + +RESOURCE = 'users' + + +def current(): + """ + Retrieve information on the logged-in user from Plotly. + + :returns: (requests.Response) Returns response directly from requests. + + """ + url = build_url(RESOURCE, route='current') + return request('get', url) diff --git a/plotly/api/v2/utils.py b/plotly/api/v2/utils.py new file mode 100644 index 00000000000..21bd0ddd016 --- /dev/null +++ b/plotly/api/v2/utils.py @@ -0,0 +1,154 @@ +from __future__ import absolute_import + +import requests +from requests.compat import json as _json +from requests.exceptions import RequestException + +from plotly import config, exceptions, version, utils +from plotly.api.utils import basic_auth + + +def make_params(**kwargs): + """ + Helper to create a params dict, skipping undefined entries. + + :returns: (dict) A params dict to pass to `request`. + + """ + return {k: v for k, v in kwargs.items() if v is not None} + + +def build_url(resource, id='', route=''): + """ + Create a url for a request on a V2 resource. + + :param (str) resource: E.g., 'files', 'plots', 'grids', etc. + :param (str) id: The unique identifier for the resource. + :param (str) route: Detail/list route. E.g., 'restore', 'lookup', etc. + :return: (str) The url. + + """ + base = config.get_config()['plotly_api_domain'] + formatter = {'base': base, 'resource': resource, 'id': id, 'route': route} + + # Add path to base url depending on the input params. Note that `route` + # can refer to a 'list' or a 'detail' route. Since it cannot refer to + # both at the same time, it's overloaded in this function. + if id: + if route: + url = '{base}/v2/{resource}/{id}/{route}'.format(**formatter) + else: + url = '{base}/v2/{resource}/{id}'.format(**formatter) + else: + if route: + url = '{base}/v2/{resource}/{route}'.format(**formatter) + else: + url = '{base}/v2/{resource}'.format(**formatter) + + return url + + +def validate_response(response): + """ + Raise a helpful PlotlyRequestError for failed requests. + + :param (requests.Response) response: A Response object from an api request. + :raises: (PlotlyRequestError) If the request failed for any reason. + :returns: (None) + + """ + if response.ok: + return + + content = response.content + status_code = response.status_code + try: + parsed_content = response.json() + except ValueError: + message = content if content else 'No Content' + raise exceptions.PlotlyRequestError(message, status_code, content) + + message = '' + if isinstance(parsed_content, dict): + errors = parsed_content.get('errors', []) + messages = [error.get('message') for error in errors] + message = '\n'.join([msg for msg in messages if msg]) + if not message: + message = content if content else 'No Content' + + raise exceptions.PlotlyRequestError(message, status_code, content) + + +def get_headers(): + """ + Using session credentials/config, get headers for a V2 API request. + + Users may have their own proxy layer and so we free up the `authorization` + header for this purpose (instead adding the user authorization in a new + `plotly-authorization` header). See pull #239. + + :returns: (dict) Headers to add to a requests.request call. + + """ + creds = config.get_credentials() + + headers = { + 'plotly-client-platform': 'python {}'.format(version.__version__), + 'content-type': 'application/json' + } + + plotly_auth = basic_auth(creds['username'], creds['api_key']) + proxy_auth = basic_auth(creds['proxy_username'], creds['proxy_password']) + + if config.get_config()['plotly_proxy_authorization']: + headers['authorization'] = proxy_auth + if creds['username'] and creds['api_key']: + headers['plotly-authorization'] = plotly_auth + else: + if creds['username'] and creds['api_key']: + headers['authorization'] = plotly_auth + + return headers + + +def request(method, url, **kwargs): + """ + Central place to make any api v2 api request. + + :param (str) method: The request method ('get', 'put', 'delete', ...). + :param (str) url: The full api url to make the request to. + :param kwargs: These are passed along (but possibly mutated) to requests. + :return: (requests.Response) The response directly from requests. + + """ + kwargs['headers'] = dict(kwargs.get('headers', {}), **get_headers()) + + # Change boolean params to lowercase strings. E.g., `True` --> `'true'`. + # Just change the value so that requests handles query string creation. + if isinstance(kwargs.get('params'), dict): + kwargs['params'] = kwargs['params'].copy() + for key in kwargs['params']: + if isinstance(kwargs['params'][key], bool): + kwargs['params'][key] = _json.dumps(kwargs['params'][key]) + + # We have a special json encoding class for non-native objects. + if kwargs.get('json') is not None: + if kwargs.get('data'): + raise exceptions.PlotlyError('Cannot supply data and json kwargs.') + kwargs['data'] = _json.dumps(kwargs.pop('json'), sort_keys=True, + cls=utils.PlotlyJSONEncoder) + + # The config file determines whether reuqests should *verify*. + kwargs['verify'] = config.get_config()['plotly_ssl_verification'] + + try: + response = requests.request(method, url, **kwargs) + except RequestException as e: + # The message can be an exception. E.g., MaxRetryError. + message = str(getattr(e, 'message', 'No message')) + response = getattr(e, 'response', None) + status_code = response.status_code if response else None + content = response.content if response else 'No content' + raise exceptions.PlotlyRequestError(message, status_code, content) + validate_response(response) + return response diff --git a/plotly/config.py b/plotly/config.py new file mode 100644 index 00000000000..dc1b8e28654 --- /dev/null +++ b/plotly/config.py @@ -0,0 +1,35 @@ +""" +Merges and prioritizes file/session config and credentials. + +This is promoted to its own module to simplify imports. + +""" +from __future__ import absolute_import + +from plotly import session, tools + + +def get_credentials(): + """Returns the credentials that will be sent to plotly.""" + credentials = tools.get_credentials_file() + session_credentials = session.get_session_credentials() + for credentials_key in credentials: + + # checking for not false, but truthy value here is the desired behavior + session_value = session_credentials.get(credentials_key) + if session_value is False or session_value: + credentials[credentials_key] = session_value + return credentials + + +def get_config(): + """Returns either module config or file config.""" + config = tools.get_config_file() + session_config = session.get_session_config() + for config_key in config: + + # checking for not false, but truthy value here is the desired behavior + session_value = session_config.get(config_key) + if session_value is False or session_value: + config[config_key] = session_value + return config diff --git a/plotly/exceptions.py b/plotly/exceptions.py index 8f7c8920454..05df864497f 100644 --- a/plotly/exceptions.py +++ b/plotly/exceptions.py @@ -5,7 +5,9 @@ A module that contains plotly's exception hierarchy. """ -import json +from __future__ import absolute_import + +from plotly.api.utils import to_native_utf8_string # Base Plotly Error @@ -18,29 +20,12 @@ class InputError(PlotlyError): class PlotlyRequestError(PlotlyError): - def __init__(self, requests_exception): - self.status_code = requests_exception.response.status_code - self.HTTPError = requests_exception - content_type = requests_exception.response.headers['content-type'] - if 'json' in content_type: - content = requests_exception.response.content - if content != '': - res_payload = json.loads( - requests_exception.response.content.decode('utf8') - ) - if 'detail' in res_payload: - self.message = res_payload['detail'] - else: - self.message = '' - else: - self.message = '' - elif content_type == 'text/plain': - self.message = requests_exception.response.content - else: - try: - self.message = requests_exception.message - except AttributeError: - self.message = 'unknown error' + """General API error. Raised for *all* failed requests.""" + + def __init__(self, message, status_code, content): + self.message = to_native_utf8_string(message) + self.status_code = status_code + self.content = content def __str__(self): return self.message diff --git a/plotly/figure_factory/_2d_density.py b/plotly/figure_factory/_2d_density.py new file mode 100644 index 00000000000..3be0bb06af1 --- /dev/null +++ b/plotly/figure_factory/_2d_density.py @@ -0,0 +1,166 @@ +from __future__ import absolute_import + +from numbers import Number + +from plotly import exceptions +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + + +def make_linear_colorscale(colors): + """ + Makes a list of colors into a colorscale-acceptable form + + For documentation regarding to the form of the output, see + https://plot.ly/python/reference/#mesh3d-colorscale + """ + scale = 1. / (len(colors) - 1) + return [[i * scale, color] for i, color in enumerate(colors)] + + +def create_2d_density(x, y, colorscale='Earth', ncontours=20, + hist_color=(0, 0, 0.5), point_color=(0, 0, 0.5), + point_size=2, title='2D Density Plot', + height=600, width=600): + """ + Returns figure for a 2D density plot + + :param (list|array) x: x-axis data for plot generation + :param (list|array) y: y-axis data for plot generation + :param (str|tuple|list) colorscale: either a plotly scale name, an rgb + or hex color, a color tuple or a list or tuple of colors. An rgb + color is of the form 'rgb(x, y, z)' where x, y, z belong to the + interval [0, 255] and a color tuple is a tuple of the form + (a, b, c) where a, b and c belong to [0, 1]. If colormap is a + list, it must contain the valid color types aforementioned as its + members. + :param (int) ncontours: the number of 2D contours to draw on the plot + :param (str) hist_color: the color of the plotted histograms + :param (str) point_color: the color of the scatter points + :param (str) point_size: the color of the scatter points + :param (str) title: set the title for the plot + :param (float) height: the height of the chart + :param (float) width: the width of the chart + + Example 1: Simple 2D Density Plot + ``` + import plotly.plotly as py + from plotly.figure_factory create_2d_density + + import numpy as np + + # Make data points + t = np.linspace(-1,1.2,2000) + x = (t**3)+(0.3*np.random.randn(2000)) + y = (t**6)+(0.3*np.random.randn(2000)) + + # Create a figure + fig = create_2D_density(x, y) + + # Plot the data + py.iplot(fig, filename='simple-2d-density') + ``` + + Example 2: Using Parameters + ``` + import plotly.plotly as py + from plotly.figure_factory create_2d_density + + import numpy as np + + # Make data points + t = np.linspace(-1,1.2,2000) + x = (t**3)+(0.3*np.random.randn(2000)) + y = (t**6)+(0.3*np.random.randn(2000)) + + # Create custom colorscale + colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)', + (1, 1, 0.2), (0.98,0.98,0.98)] + + # Create a figure + fig = create_2D_density( + x, y, colorscale=colorscale, + hist_color='rgb(255, 237, 222)', point_size=3) + + # Plot the data + py.iplot(fig, filename='use-parameters') + ``` + """ + + # validate x and y are filled with numbers only + for array in [x, y]: + if not all(isinstance(element, Number) for element in array): + raise exceptions.PlotlyError( + "All elements of your 'x' and 'y' lists must be numbers." + ) + + # validate x and y are the same length + if len(x) != len(y): + raise exceptions.PlotlyError( + "Both lists 'x' and 'y' must be the same length." + ) + + colorscale = utils.validate_colors(colorscale, 'rgb') + colorscale = make_linear_colorscale(colorscale) + + # validate hist_color and point_color + hist_color = utils.validate_colors(hist_color, 'rgb') + point_color = utils.validate_colors(point_color, 'rgb') + + trace1 = graph_objs.Scatter( + x=x, y=y, mode='markers', name='points', + marker=dict( + color=point_color[0], + size=point_size, + opacity=0.4 + ) + ) + trace2 = graph_objs.Histogram2dcontour( + x=x, y=y, name='density', ncontours=ncontours, + colorscale=colorscale, reversescale=True, showscale=False + ) + trace3 = graph_objs.Histogram( + x=x, name='x density', + marker=dict(color=hist_color[0]), yaxis='y2' + ) + trace4 = graph_objs.Histogram( + y=y, name='y density', + marker=dict(color=hist_color[0]), xaxis='x2' + ) + data = [trace1, trace2, trace3, trace4] + + layout = graph_objs.Layout( + showlegend=False, + autosize=False, + title=title, + height=height, + width=width, + xaxis=dict( + domain=[0, 0.85], + showgrid=False, + zeroline=False + ), + yaxis=dict( + domain=[0, 0.85], + showgrid=False, + zeroline=False + ), + margin=dict( + t=50 + ), + hovermode='closest', + bargap=0, + xaxis2=dict( + domain=[0.85, 1], + showgrid=False, + zeroline=False + ), + yaxis2=dict( + domain=[0.85, 1], + showgrid=False, + zeroline=False + ) + ) + + fig = graph_objs.Figure(data=data, layout=layout) + return fig diff --git a/plotly/figure_factory/__init__.py b/plotly/figure_factory/__init__.py new file mode 100644 index 00000000000..153ac6d657b --- /dev/null +++ b/plotly/figure_factory/__init__.py @@ -0,0 +1,18 @@ +from __future__ import absolute_import + +# Require that numpy exists for figure_factory +import numpy + +from plotly.figure_factory._2d_density import create_2d_density +from plotly.figure_factory._annotated_heatmap import create_annotated_heatmap +from plotly.figure_factory._candlestick import create_candlestick +from plotly.figure_factory._dendrogram import create_dendrogram +from plotly.figure_factory._distplot import create_distplot +from plotly.figure_factory._gantt import create_gantt +from plotly.figure_factory._ohlc import create_ohlc +from plotly.figure_factory._quiver import create_quiver +from plotly.figure_factory._scatterplot import create_scatterplotmatrix +from plotly.figure_factory._streamline import create_streamline +from plotly.figure_factory._table import create_table +from plotly.figure_factory._trisurf import create_trisurf +from plotly.figure_factory._violin import create_violin diff --git a/plotly/figure_factory/_annotated_heatmap.py b/plotly/figure_factory/_annotated_heatmap.py new file mode 100644 index 00000000000..36fc1ff7c16 --- /dev/null +++ b/plotly/figure_factory/_annotated_heatmap.py @@ -0,0 +1,239 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + +# Optional imports, may be None for users that only use our core functionality. +np = optional_imports.get_module('numpy') + + +def validate_annotated_heatmap(z, x, y, annotation_text): + """ + Annotated-heatmap-specific validations + + Check that if a text matrix is supplied, it has the same + dimensions as the z matrix. + + See FigureFactory.create_annotated_heatmap() for params + + :raises: (PlotlyError) If z and text matrices do not have the same + dimensions. + """ + if annotation_text is not None and isinstance(annotation_text, list): + utils.validate_equal_length(z, annotation_text) + for lst in range(len(z)): + if len(z[lst]) != len(annotation_text[lst]): + raise exceptions.PlotlyError("z and text should have the " + "same dimensions") + + if x: + if len(x) != len(z[0]): + raise exceptions.PlotlyError("oops, the x list that you " + "provided does not match the " + "width of your z matrix ") + + if y: + if len(y) != len(z): + raise exceptions.PlotlyError("oops, the y list that you " + "provided does not match the " + "length of your z matrix ") + + +def create_annotated_heatmap(z, x=None, y=None, annotation_text=None, + colorscale='RdBu', font_colors=None, + showscale=False, reversescale=False, + **kwargs): + """ + BETA function that creates annotated heatmaps + + This function adds annotations to each cell of the heatmap. + + :param (list[list]|ndarray) z: z matrix to create heatmap. + :param (list) x: x axis labels. + :param (list) y: y axis labels. + :param (list[list]|ndarray) annotation_text: Text strings for + annotations. Should have the same dimensions as the z matrix. If no + text is added, the values of the z matrix are annotated. Default = + z matrix values. + :param (list|str) colorscale: heatmap colorscale. + :param (list) font_colors: List of two color strings: [min_text_color, + max_text_color] where min_text_color is applied to annotations for + heatmap values < (max_value - min_value)/2. If font_colors is not + defined, the colors are defined logically as black or white + depending on the heatmap's colorscale. + :param (bool) showscale: Display colorscale. Default = False + :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. + These kwargs describe other attributes about the annotated Heatmap + trace such as the colorscale. For more information on valid kwargs + call help(plotly.graph_objs.Heatmap) + + Example 1: Simple annotated heatmap with default configuration + ``` + import plotly.plotly as py + from plotly.figure_factory create_annotated_heatmap + + z = [[0.300000, 0.00000, 0.65, 0.300000], + [1, 0.100005, 0.45, 0.4300], + [0.300000, 0.00000, 0.65, 0.300000], + [1, 0.100005, 0.45, 0.00000]] + + figure = create_annotated_heatmap(z) + py.iplot(figure) + ``` + """ + + # Avoiding mutables in the call signature + font_colors = font_colors if font_colors is not None else [] + validate_annotated_heatmap(z, x, y, annotation_text) + annotations = _AnnotatedHeatmap(z, x, y, annotation_text, + colorscale, font_colors, reversescale, + **kwargs).make_annotations() + + if x or y: + trace = dict(type='heatmap', z=z, x=x, y=y, colorscale=colorscale, + showscale=showscale, **kwargs) + layout = dict(annotations=annotations, + xaxis=dict(ticks='', dtick=1, side='top', + gridcolor='rgb(0, 0, 0)'), + yaxis=dict(ticks='', dtick=1, ticksuffix=' ')) + else: + trace = dict(type='heatmap', z=z, colorscale=colorscale, + showscale=showscale, **kwargs) + layout = dict(annotations=annotations, + xaxis=dict(ticks='', side='top', + gridcolor='rgb(0, 0, 0)', + showticklabels=False), + yaxis=dict(ticks='', ticksuffix=' ', + showticklabels=False)) + + data = [trace] + + return graph_objs.Figure(data=data, layout=layout) + + +class _AnnotatedHeatmap(object): + """ + Refer to TraceFactory.create_annotated_heatmap() for docstring + """ + def __init__(self, z, x, y, annotation_text, colorscale, + font_colors, reversescale, **kwargs): + + self.z = z + if x: + self.x = x + else: + self.x = range(len(z[0])) + if y: + self.y = y + else: + self.y = range(len(z)) + if annotation_text is not None: + self.annotation_text = annotation_text + else: + self.annotation_text = self.z + self.colorscale = colorscale + self.reversescale = reversescale + self.font_colors = font_colors + + def get_text_color(self): + """ + Get font color for annotations. + + The annotated heatmap can feature two text colors: min_text_color and + max_text_color. The min_text_color is applied to annotations for + heatmap values < (max_value - min_value)/2. The user can define these + two colors. Otherwise the colors are defined logically as black or + white depending on the heatmap's colorscale. + + :rtype (string, string) min_text_color, max_text_color: text + color for annotations for heatmap values < + (max_value - min_value)/2 and text color for annotations for + heatmap values >= (max_value - min_value)/2 + """ + # Plotly colorscales ranging from a lighter shade to a darker shade + colorscales = ['Greys', 'Greens', 'Blues', + 'YIGnBu', 'YIOrRd', 'RdBu', + 'Picnic', 'Jet', 'Hot', 'Blackbody', + 'Earth', 'Electric', 'Viridis'] + # Plotly colorscales ranging from a darker shade to a lighter shade + colorscales_reverse = ['Reds'] + if self.font_colors: + min_text_color = self.font_colors[0] + max_text_color = self.font_colors[-1] + elif self.colorscale in colorscales and self.reversescale: + min_text_color = '#000000' + max_text_color = '#FFFFFF' + elif self.colorscale in colorscales: + min_text_color = '#FFFFFF' + max_text_color = '#000000' + elif self.colorscale in colorscales_reverse and self.reversescale: + min_text_color = '#FFFFFF' + max_text_color = '#000000' + elif self.colorscale in colorscales_reverse: + min_text_color = '#000000' + max_text_color = '#FFFFFF' + elif isinstance(self.colorscale, list): + if 'rgb' in self.colorscale[0][1]: + min_col = map(int, + self.colorscale[0][1].strip('rgb()').split(',')) + max_col = map(int, + self.colorscale[-1][1].strip('rgb()').split(',')) + elif '#' in self.colorscale[0][1]: + min_col = utils.hex_to_rgb(self.colorscale[0][1]) + max_col = utils.hex_to_rgb(self.colorscale[-1][1]) + else: + min_col = [255, 255, 255] + max_col = [255, 255, 255] + + if (min_col[0]*0.299 + min_col[1]*0.587 + min_col[2]*0.114) > 186: + min_text_color = '#000000' + else: + min_text_color = '#FFFFFF' + if (max_col[0]*0.299 + max_col[1]*0.587 + max_col[2]*0.114) > 186: + max_text_color = '#000000' + else: + max_text_color = '#FFFFFF' + else: + min_text_color = '#000000' + max_text_color = '#000000' + return min_text_color, max_text_color + + def get_z_mid(self): + """ + Get the mid value of z matrix + + :rtype (float) z_avg: average val from z matrix + """ + if np and isinstance(self.z, np.ndarray): + z_min = np.amin(self.z) + z_max = np.amax(self.z) + else: + z_min = min(min(self.z)) + z_max = max(max(self.z)) + z_mid = (z_max+z_min) / 2 + return z_mid + + def make_annotations(self): + """ + Get annotations for each cell of the heatmap with graph_objs.Annotation + + :rtype (list[dict]) annotations: list of annotations for each cell of + the heatmap + """ + min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self) + z_mid = _AnnotatedHeatmap.get_z_mid(self) + annotations = [] + for n, row in enumerate(self.z): + for m, val in enumerate(row): + font_color = min_text_color if val < z_mid else max_text_color + annotations.append( + graph_objs.Annotation( + text=str(self.annotation_text[n][m]), + x=self.x[m], + y=self.y[n], + xref='x1', + yref='y1', + font=dict(color=font_color), + showarrow=False)) + return annotations diff --git a/plotly/figure_factory/_candlestick.py b/plotly/figure_factory/_candlestick.py new file mode 100644 index 00000000000..925b4c1a62b --- /dev/null +++ b/plotly/figure_factory/_candlestick.py @@ -0,0 +1,294 @@ +from __future__ import absolute_import + +from plotly.figure_factory import utils +from plotly.figure_factory._ohlc import (_DEFAULT_INCREASING_COLOR, + _DEFAULT_DECREASING_COLOR, + validate_ohlc) +from plotly.graph_objs import graph_objs + + +def make_increasing_candle(open, high, low, close, dates, **kwargs): + """ + Makes boxplot trace for increasing candlesticks + + _make_increasing_candle() and _make_decreasing_candle separate the + increasing traces from the decreasing traces so kwargs (such as + color) can be passed separately to increasing or decreasing traces + when direction is set to 'increasing' or 'decreasing' in + FigureFactory.create_candlestick() + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. + + :rtype (list) candle_incr_data: list of the box trace for + increasing candlesticks. + """ + increase_x, increase_y = _Candlestick( + open, high, low, close, dates, **kwargs).get_candle_increase() + + if 'line' in kwargs: + kwargs.setdefault('fillcolor', kwargs['line']['color']) + else: + kwargs.setdefault('fillcolor', _DEFAULT_INCREASING_COLOR) + if 'name' in kwargs: + kwargs.setdefault('showlegend', True) + else: + kwargs.setdefault('showlegend', False) + kwargs.setdefault('name', 'Increasing') + kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR)) + + candle_incr_data = dict(type='box', + x=increase_x, + y=increase_y, + whiskerwidth=0, + boxpoints=False, + **kwargs) + + return [candle_incr_data] + + +def make_decreasing_candle(open, high, low, close, dates, **kwargs): + """ + Makes boxplot trace for decreasing candlesticks + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to decreasing trace via + plotly.graph_objs.Scatter. + + :rtype (list) candle_decr_data: list of the box trace for + decreasing candlesticks. + """ + + decrease_x, decrease_y = _Candlestick( + open, high, low, close, dates, **kwargs).get_candle_decrease() + + if 'line' in kwargs: + kwargs.setdefault('fillcolor', kwargs['line']['color']) + else: + kwargs.setdefault('fillcolor', _DEFAULT_DECREASING_COLOR) + kwargs.setdefault('showlegend', False) + kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR)) + kwargs.setdefault('name', 'Decreasing') + + candle_decr_data = dict(type='box', + x=decrease_x, + y=decrease_y, + whiskerwidth=0, + boxpoints=False, + **kwargs) + + return [candle_decr_data] + + +def create_candlestick(open, high, low, close, dates=None, direction='both', + **kwargs): + """ + BETA function that creates a candlestick chart + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param (string) direction: direction can be 'increasing', 'decreasing', + or 'both'. When the direction is 'increasing', the returned figure + consists of all candlesticks where the close value is greater than + the corresponding open value, and when the direction is + 'decreasing', the returned figure consists of all candlesticks + where the close value is less than or equal to the corresponding + open value. When the direction is 'both', both increasing and + decreasing candlesticks are returned. Default: 'both' + :param kwargs: kwargs passed through plotly.graph_objs.Scatter. + These kwargs describe other attributes about the ohlc Scatter trace + such as the color or the legend name. For more information on valid + kwargs call help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of candlestick chart figure. + + Example 1: Simple candlestick chart from a Pandas DataFrame + ``` + import plotly.plotly as py + from plotly.figure_factory import create_candlestick + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1), datetime(2009, 4, 1)) + fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index) + py.plot(fig, filename='finance/aapl-candlestick', validate=False) + ``` + + Example 2: Add text and annotations to the candlestick chart + ``` + fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index) + # Update the fig - all options here: https://plot.ly/python/reference/#Layout + fig['layout'].update({ + 'title': 'The Great Recession', + 'yaxis': {'title': 'AAPL Stock'}, + 'shapes': [{ + 'x0': '2007-12-01', 'x1': '2007-12-01', + 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper', + 'line': {'color': 'rgb(30,30,30)', 'width': 1} + }], + 'annotations': [{ + 'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper', + 'showarrow': False, 'xanchor': 'left', + 'text': 'Official start of the recession' + }] + }) + py.plot(fig, filename='finance/aapl-recession-candlestick', validate=False) + ``` + + Example 3: Customize the candlestick colors + ``` + import plotly.plotly as py + from plotly.figure_factory import create_candlestick + from plotly.graph_objs import Line, Marker + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1)) + + # Make increasing candlesticks and customize their color and name + fig_increasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index, + direction='increasing', name='AAPL', + marker=Marker(color='rgb(150, 200, 250)'), + line=Line(color='rgb(150, 200, 250)')) + + # Make decreasing candlesticks and customize their color and name + fig_decreasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index, + direction='decreasing', + marker=Marker(color='rgb(128, 128, 128)'), + line=Line(color='rgb(128, 128, 128)')) + + # Initialize the figure + fig = fig_increasing + + # Add decreasing data with .extend() + fig['data'].extend(fig_decreasing['data']) + + py.iplot(fig, filename='finance/aapl-candlestick-custom', validate=False) + ``` + + Example 4: Candlestick chart with datetime objects + ``` + import plotly.plotly as py + from plotly.figure_factory import create_candlestick + + from datetime import datetime + + # Add data + open_data = [33.0, 33.3, 33.5, 33.0, 34.1] + high_data = [33.1, 33.3, 33.6, 33.2, 34.8] + low_data = [32.7, 32.7, 32.8, 32.6, 32.8] + close_data = [33.0, 32.9, 33.3, 33.1, 33.1] + dates = [datetime(year=2013, month=10, day=10), + datetime(year=2013, month=11, day=10), + datetime(year=2013, month=12, day=10), + datetime(year=2014, month=1, day=10), + datetime(year=2014, month=2, day=10)] + + # Create ohlc + fig = create_candlestick(open_data, high_data, + low_data, close_data, dates=dates) + + py.iplot(fig, filename='finance/simple-candlestick', validate=False) + ``` + """ + if dates is not None: + utils.validate_equal_length(open, high, low, close, dates) + else: + utils.validate_equal_length(open, high, low, close) + validate_ohlc(open, high, low, close, direction, **kwargs) + + if direction is 'increasing': + candle_incr_data = make_increasing_candle(open, high, low, close, + dates, **kwargs) + data = candle_incr_data + elif direction is 'decreasing': + candle_decr_data = make_decreasing_candle(open, high, low, close, + dates, **kwargs) + data = candle_decr_data + else: + candle_incr_data = make_increasing_candle(open, high, low, close, + dates, **kwargs) + candle_decr_data = make_decreasing_candle(open, high, low, close, + dates, **kwargs) + data = candle_incr_data + candle_decr_data + + layout = graph_objs.Layout() + return graph_objs.Figure(data=data, layout=layout) + + +class _Candlestick(object): + """ + Refer to FigureFactory.create_candlestick() for docstring. + """ + def __init__(self, open, high, low, close, dates, **kwargs): + self.open = open + self.high = high + self.low = low + self.close = close + if dates is not None: + self.x = dates + else: + self.x = [x for x in range(len(self.open))] + self.get_candle_increase() + + def get_candle_increase(self): + """ + Separate increasing data from decreasing data. + + The data is increasing when close value > open value + and decreasing when the close value <= open value. + """ + increase_y = [] + increase_x = [] + for index in range(len(self.open)): + if self.close[index] > self.open[index]: + increase_y.append(self.low[index]) + increase_y.append(self.open[index]) + increase_y.append(self.close[index]) + increase_y.append(self.close[index]) + increase_y.append(self.close[index]) + increase_y.append(self.high[index]) + increase_x.append(self.x[index]) + + increase_x = [[x, x, x, x, x, x] for x in increase_x] + increase_x = utils.flatten(increase_x) + + return increase_x, increase_y + + def get_candle_decrease(self): + """ + Separate increasing data from decreasing data. + + The data is increasing when close value > open value + and decreasing when the close value <= open value. + """ + decrease_y = [] + decrease_x = [] + for index in range(len(self.open)): + if self.close[index] <= self.open[index]: + decrease_y.append(self.low[index]) + decrease_y.append(self.open[index]) + decrease_y.append(self.close[index]) + decrease_y.append(self.close[index]) + decrease_y.append(self.close[index]) + decrease_y.append(self.high[index]) + decrease_x.append(self.x[index]) + + decrease_x = [[x, x, x, x, x, x] for x in decrease_x] + decrease_x = utils.flatten(decrease_x) + + return decrease_x, decrease_y diff --git a/plotly/figure_factory/_dendrogram.py b/plotly/figure_factory/_dendrogram.py new file mode 100644 index 00000000000..26e6692290d --- /dev/null +++ b/plotly/figure_factory/_dendrogram.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import + +from collections import OrderedDict + +from plotly import exceptions, optional_imports +from plotly.graph_objs import graph_objs + +# Optional imports, may be None for users that only use our core functionality. +np = optional_imports.get_module('numpy') +scp = optional_imports.get_module('scipy') +sch = optional_imports.get_module('scipy.cluster.hierarchy') +scs = optional_imports.get_module('scipy.spatial') + + +def create_dendrogram(X, orientation="bottom", labels=None, + colorscale=None, distfun=None, + linkagefun=lambda x: sch.linkage(x, 'complete')): + """ + BETA function that returns a dendrogram Plotly figure object. + + :param (ndarray) X: Matrix of observations as array of arrays + :param (str) orientation: 'top', 'right', 'bottom', or 'left' + :param (list) labels: List of axis category labels(observation labels) + :param (list) colorscale: Optional colorscale for dendrogram tree + :param (function) distfun: Function to compute the pairwise distance from + the observations + :param (function) linkagefun: Function to compute the linkage matrix from + the pairwise distances + + clusters + + Example 1: Simple bottom oriented dendrogram + ``` + import plotly.plotly as py + from plotly.figure_factory import create_dendrogram + + import numpy as np + + X = np.random.rand(10,10) + dendro = create_dendrogram(X) + plot_url = py.plot(dendro, filename='simple-dendrogram') + + ``` + + Example 2: Dendrogram to put on the left of the heatmap + ``` + import plotly.plotly as py + from plotly.figure_factory import create_dendrogram + + import numpy as np + + X = np.random.rand(5,5) + names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark'] + dendro = create_dendrogram(X, orientation='right', labels=names) + dendro['layout'].update({'width':700, 'height':500}) + + py.iplot(dendro, filename='vertical-dendrogram') + ``` + + Example 3: Dendrogram with Pandas + ``` + import plotly.plotly as py + from plotly.figure_factory import create_dendrogram + + import numpy as np + import pandas as pd + + Index= ['A','B','C','D','E','F','G','H','I','J'] + df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index) + fig = create_dendrogram(df, labels=Index) + url = py.plot(fig, filename='pandas-dendrogram') + ``` + """ + if not scp or not scs or not sch: + raise ImportError("FigureFactory.create_dendrogram requires scipy, \ + scipy.spatial and scipy.hierarchy") + + s = X.shape + if len(s) != 2: + exceptions.PlotlyError("X should be 2-dimensional array.") + + if distfun is None: + distfun = scs.distance.pdist + + dendrogram = _Dendrogram(X, orientation, labels, colorscale, + distfun=distfun, linkagefun=linkagefun) + + return {'layout': dendrogram.layout, + 'data': dendrogram.data} + + +class _Dendrogram(object): + """Refer to FigureFactory.create_dendrogram() for docstring.""" + + def __init__(self, X, orientation='bottom', labels=None, colorscale=None, + width="100%", height="100%", xaxis='xaxis', yaxis='yaxis', + distfun=None, + linkagefun=lambda x: sch.linkage(x, 'complete')): + self.orientation = orientation + self.labels = labels + self.xaxis = xaxis + self.yaxis = yaxis + self.data = [] + self.leaves = [] + self.sign = {self.xaxis: 1, self.yaxis: 1} + self.layout = {self.xaxis: {}, self.yaxis: {}} + + if self.orientation in ['left', 'bottom']: + self.sign[self.xaxis] = 1 + else: + self.sign[self.xaxis] = -1 + + if self.orientation in ['right', 'bottom']: + self.sign[self.yaxis] = 1 + else: + self.sign[self.yaxis] = -1 + + if distfun is None: + distfun = scs.distance.pdist + + (dd_traces, xvals, yvals, + ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale, + distfun, + linkagefun) + + self.labels = ordered_labels + self.leaves = leaves + yvals_flat = yvals.flatten() + xvals_flat = xvals.flatten() + + self.zero_vals = [] + + for i in range(len(yvals_flat)): + if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals: + self.zero_vals.append(xvals_flat[i]) + + self.zero_vals.sort() + + self.layout = self.set_figure_layout(width, height) + self.data = graph_objs.Data(dd_traces) + + def get_color_dict(self, colorscale): + """ + Returns colorscale used for dendrogram tree clusters. + + :param (list) colorscale: Colors to use for the plot in rgb format. + :rtype (dict): A dict of default colors mapped to the user colorscale. + + """ + + # These are the color codes returned for dendrograms + # We're replacing them with nicer colors + d = {'r': 'red', + 'g': 'green', + 'b': 'blue', + 'c': 'cyan', + 'm': 'magenta', + 'y': 'yellow', + 'k': 'black', + 'w': 'white'} + default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0])) + + if colorscale is None: + colorscale = [ + 'rgb(0,116,217)', # blue + 'rgb(35,205,205)', # cyan + 'rgb(61,153,112)', # green + 'rgb(40,35,35)', # black + 'rgb(133,20,75)', # magenta + 'rgb(255,65,54)', # red + 'rgb(255,255,255)', # white + 'rgb(255,220,0)'] # yellow + + for i in range(len(default_colors.keys())): + k = list(default_colors.keys())[i] # PY3 won't index keys + if i < len(colorscale): + default_colors[k] = colorscale[i] + + return default_colors + + def set_axis_layout(self, axis_key): + """ + Sets and returns default axis object for dendrogram figure. + + :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc. + :rtype (dict): An axis_key dictionary with set parameters. + + """ + axis_defaults = { + 'type': 'linear', + 'ticks': 'outside', + 'mirror': 'allticks', + 'rangemode': 'tozero', + 'showticklabels': True, + 'zeroline': False, + 'showgrid': False, + 'showline': True, + } + + if len(self.labels) != 0: + axis_key_labels = self.xaxis + if self.orientation in ['left', 'right']: + axis_key_labels = self.yaxis + if axis_key_labels not in self.layout: + self.layout[axis_key_labels] = {} + self.layout[axis_key_labels]['tickvals'] = \ + [zv*self.sign[axis_key] for zv in self.zero_vals] + self.layout[axis_key_labels]['ticktext'] = self.labels + self.layout[axis_key_labels]['tickmode'] = 'array' + + self.layout[axis_key].update(axis_defaults) + + return self.layout[axis_key] + + def set_figure_layout(self, width, height): + """ + Sets and returns default layout object for dendrogram figure. + + """ + self.layout.update({ + 'showlegend': False, + 'autosize': False, + 'hovermode': 'closest', + 'width': width, + 'height': height + }) + + self.set_axis_layout(self.xaxis) + self.set_axis_layout(self.yaxis) + + return self.layout + + def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun): + """ + Calculates all the elements needed for plotting a dendrogram. + + :param (ndarray) X: Matrix of observations as array of arrays + :param (list) colorscale: Color scale for dendrogram tree clusters + :param (function) distfun: Function to compute the pairwise distance + from the observations + :param (function) linkagefun: Function to compute the linkage matrix + from the pairwise distances + :rtype (tuple): Contains all the traces in the following order: + (a) trace_list: List of Plotly trace objects for dendrogram tree + (b) icoord: All X points of the dendrogram tree as array of arrays + with length 4 + (c) dcoord: All Y points of the dendrogram tree as array of arrays + with length 4 + (d) ordered_labels: leaf labels in the order they are going to + appear on the plot + (e) P['leaves']: left-to-right traversal of the leaves + + """ + d = distfun(X) + Z = linkagefun(d) + P = sch.dendrogram(Z, orientation=self.orientation, + labels=self.labels, no_plot=True) + + icoord = scp.array(P['icoord']) + dcoord = scp.array(P['dcoord']) + ordered_labels = scp.array(P['ivl']) + color_list = scp.array(P['color_list']) + colors = self.get_color_dict(colorscale) + + trace_list = [] + + for i in range(len(icoord)): + # xs and ys are arrays of 4 points that make up the '∩' shapes + # of the dendrogram tree + if self.orientation in ['top', 'bottom']: + xs = icoord[i] + else: + xs = dcoord[i] + + if self.orientation in ['top', 'bottom']: + ys = dcoord[i] + else: + ys = icoord[i] + color_key = color_list[i] + trace = graph_objs.Scatter( + x=np.multiply(self.sign[self.xaxis], xs), + y=np.multiply(self.sign[self.yaxis], ys), + mode='lines', + marker=graph_objs.Marker(color=colors[color_key]) + ) + + try: + x_index = int(self.xaxis[-1]) + except ValueError: + x_index = '' + + try: + y_index = int(self.yaxis[-1]) + except ValueError: + y_index = '' + + trace['xaxis'] = 'x' + x_index + trace['yaxis'] = 'y' + y_index + + trace_list.append(trace) + + return trace_list, icoord, dcoord, ordered_labels, P['leaves'] diff --git a/plotly/figure_factory/_distplot.py b/plotly/figure_factory/_distplot.py new file mode 100644 index 00000000000..d2b77fb0873 --- /dev/null +++ b/plotly/figure_factory/_distplot.py @@ -0,0 +1,390 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + +# Optional imports, may be None for users that only use our core functionality. +np = optional_imports.get_module('numpy') +pd = optional_imports.get_module('pandas') +scipy = optional_imports.get_module('scipy') +scipy_stats = optional_imports.get_module('scipy.stats') + + +DEFAULT_HISTNORM = 'probability density' +ALTERNATIVE_HISTNORM = 'probability' + + +def validate_distplot(hist_data, curve_type): + """ + Distplot-specific validations + + :raises: (PlotlyError) If hist_data is not a list of lists + :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or + 'normal'). + """ + hist_data_types = (list,) + if np: + hist_data_types += (np.ndarray,) + if pd: + hist_data_types += (pd.core.series.Series,) + + if not isinstance(hist_data[0], hist_data_types): + raise exceptions.PlotlyError("Oops, this function was written " + "to handle multiple datasets, if " + "you want to plot just one, make " + "sure your hist_data variable is " + "still a list of lists, i.e. x = " + "[1, 2, 3] -> x = [[1, 2, 3]]") + + curve_opts = ('kde', 'normal') + if curve_type not in curve_opts: + raise exceptions.PlotlyError("curve_type must be defined as " + "'kde' or 'normal'") + + if not scipy: + raise ImportError("FigureFactory.create_distplot requires scipy") + + +def create_distplot(hist_data, group_labels, bin_size=1., curve_type='kde', + colors=None, rug_text=None, histnorm=DEFAULT_HISTNORM, + show_hist=True, show_curve=True, show_rug=True): + """ + BETA function that creates a distplot similar to seaborn.distplot + + The distplot can be composed of all or any combination of the following + 3 components: (1) histogram, (2) curve: (a) kernel density estimation + or (b) normal curve, and (3) rug plot. Additionally, multiple distplots + (from multiple datasets) can be created in the same plot. + + :param (list[list]) hist_data: Use list of lists to plot multiple data + sets on the same plot. + :param (list[str]) group_labels: Names for each data set. + :param (list[float]|float) bin_size: Size of histogram bins. + Default = 1. + :param (str) curve_type: 'kde' or 'normal'. Default = 'kde' + :param (str) histnorm: 'probability density' or 'probability' + Default = 'probability density' + :param (bool) show_hist: Add histogram to distplot? Default = True + :param (bool) show_curve: Add curve to distplot? Default = True + :param (bool) show_rug: Add rug to distplot? Default = True + :param (list[str]) colors: Colors for traces. + :param (list[list]) rug_text: Hovertext values for rug_plot, + :return (dict): Representation of a distplot figure. + + Example 1: Simple distplot of 1 data set + ``` + import plotly.plotly as py + from plotly.figure_factory import create_distplot + + hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5, + 3.5, 4.1, 4.4, 4.5, 4.5, + 5.0, 5.0, 5.2, 5.5, 5.5, + 5.5, 5.5, 5.5, 6.1, 7.0]] + + group_labels = ['distplot example'] + + fig = create_distplot(hist_data, group_labels) + + url = py.plot(fig, filename='Simple distplot', validate=False) + ``` + + Example 2: Two data sets and added rug text + ``` + import plotly.plotly as py + from plotly.figure_factory import create_distplot + + # Add histogram data + hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6, + -0.9, -0.07, 1.95, 0.9, -0.2, + -0.5, 0.3, 0.4, -0.37, 0.6] + hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59, + 1.0, 0.8, 1.7, 0.5, 0.8, + -0.3, 1.2, 0.56, 0.3, 2.2] + + # Group data together + hist_data = [hist1_x, hist2_x] + + group_labels = ['2012', '2013'] + + # Add text + rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1', + 'f1', 'g1', 'h1', 'i1', 'j1', + 'k1', 'l1', 'm1', 'n1', 'o1'] + + rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2', + 'f2', 'g2', 'h2', 'i2', 'j2', + 'k2', 'l2', 'm2', 'n2', 'o2'] + + # Group text together + rug_text_all = [rug_text_1, rug_text_2] + + # Create distplot + fig = create_distplot( + hist_data, group_labels, rug_text=rug_text_all, bin_size=.2) + + # Add title + fig['layout'].update(title='Dist Plot') + + # Plot! + url = py.plot(fig, filename='Distplot with rug text', validate=False) + ``` + + Example 3: Plot with normal curve and hide rug plot + ``` + import plotly.plotly as py + from plotly.figure_factory import create_distplot + import numpy as np + + x1 = np.random.randn(190) + x2 = np.random.randn(200)+1 + x3 = np.random.randn(200)-1 + x4 = np.random.randn(210)+2 + + hist_data = [x1, x2, x3, x4] + group_labels = ['2012', '2013', '2014', '2015'] + + fig = create_distplot( + hist_data, group_labels, curve_type='normal', + show_rug=False, bin_size=.4) + + url = py.plot(fig, filename='hist and normal curve', validate=False) + + Example 4: Distplot with Pandas + ``` + import plotly.plotly as py + from plotly.figure_factory import create_distplot + import numpy as np + import pandas as pd + + df = pd.DataFrame({'2012': np.random.randn(200), + '2013': np.random.randn(200)+1}) + py.iplot(create_distplot([df[c] for c in df.columns], df.columns), + filename='examples/distplot with pandas', + validate=False) + ``` + """ + if colors is None: + colors = [] + if rug_text is None: + rug_text = [] + + validate_distplot(hist_data, curve_type) + utils.validate_equal_length(hist_data, group_labels) + + if isinstance(bin_size, (float, int)): + bin_size = [bin_size] * len(hist_data) + + hist = _Distplot( + hist_data, histnorm, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_hist() + + if curve_type == 'normal': + curve = _Distplot( + hist_data, histnorm, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_normal() + else: + curve = _Distplot( + hist_data, histnorm, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_kde() + + rug = _Distplot( + hist_data, histnorm, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_rug() + + data = [] + if show_hist: + data.append(hist) + if show_curve: + data.append(curve) + if show_rug: + data.append(rug) + layout = graph_objs.Layout( + barmode='overlay', + hovermode='closest', + legend=dict(traceorder='reversed'), + xaxis1=dict(domain=[0.0, 1.0], + anchor='y2', + zeroline=False), + yaxis1=dict(domain=[0.35, 1], + anchor='free', + position=0.0), + yaxis2=dict(domain=[0, 0.25], + anchor='x1', + dtick=1, + showticklabels=False)) + else: + layout = graph_objs.Layout( + barmode='overlay', + hovermode='closest', + legend=dict(traceorder='reversed'), + xaxis1=dict(domain=[0.0, 1.0], + anchor='y2', + zeroline=False), + yaxis1=dict(domain=[0., 1], + anchor='free', + position=0.0)) + + data = sum(data, []) + return graph_objs.Figure(data=data, layout=layout) + + +class _Distplot(object): + """ + Refer to TraceFactory.create_distplot() for docstring + """ + def __init__(self, hist_data, histnorm, group_labels, + bin_size, curve_type, colors, + rug_text, show_hist, show_curve): + self.hist_data = hist_data + self.histnorm = histnorm + self.group_labels = group_labels + self.bin_size = bin_size + self.show_hist = show_hist + self.show_curve = show_curve + self.trace_number = len(hist_data) + if rug_text: + self.rug_text = rug_text + else: + self.rug_text = [None] * self.trace_number + + self.start = [] + self.end = [] + if colors: + self.colors = colors + else: + self.colors = [ + "rgb(31, 119, 180)", "rgb(255, 127, 14)", + "rgb(44, 160, 44)", "rgb(214, 39, 40)", + "rgb(148, 103, 189)", "rgb(140, 86, 75)", + "rgb(227, 119, 194)", "rgb(127, 127, 127)", + "rgb(188, 189, 34)", "rgb(23, 190, 207)"] + self.curve_x = [None] * self.trace_number + self.curve_y = [None] * self.trace_number + + for trace in self.hist_data: + self.start.append(min(trace) * 1.) + self.end.append(max(trace) * 1.) + + def make_hist(self): + """ + Makes the histogram(s) for FigureFactory.create_distplot(). + + :rtype (list) hist: list of histogram representations + """ + hist = [None] * self.trace_number + + for index in range(self.trace_number): + hist[index] = dict(type='histogram', + x=self.hist_data[index], + xaxis='x1', + yaxis='y1', + histnorm=self.histnorm, + name=self.group_labels[index], + legendgroup=self.group_labels[index], + marker=dict(color=self.colors[index]), + autobinx=False, + xbins=dict(start=self.start[index], + end=self.end[index], + size=self.bin_size[index]), + opacity=.7) + return hist + + def make_kde(self): + """ + Makes the kernel density estimation(s) for create_distplot(). + + This is called when curve_type = 'kde' in create_distplot(). + + :rtype (list) curve: list of kde representations + """ + curve = [None] * self.trace_number + for index in range(self.trace_number): + self.curve_x[index] = [self.start[index] + + x * (self.end[index] - self.start[index]) + / 500 for x in range(500)] + self.curve_y[index] = (scipy_stats.gaussian_kde + (self.hist_data[index]) + (self.curve_x[index])) + + if self.histnorm == ALTERNATIVE_HISTNORM: + self.curve_y[index] *= self.bin_size[index] + + for index in range(self.trace_number): + curve[index] = dict(type='scatter', + x=self.curve_x[index], + y=self.curve_y[index], + xaxis='x1', + yaxis='y1', + mode='lines', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=False if self.show_hist else True, + marker=dict(color=self.colors[index])) + return curve + + def make_normal(self): + """ + Makes the normal curve(s) for create_distplot(). + + This is called when curve_type = 'normal' in create_distplot(). + + :rtype (list) curve: list of normal curve representations + """ + curve = [None] * self.trace_number + mean = [None] * self.trace_number + sd = [None] * self.trace_number + + for index in range(self.trace_number): + mean[index], sd[index] = (scipy_stats.norm.fit + (self.hist_data[index])) + self.curve_x[index] = [self.start[index] + + x * (self.end[index] - self.start[index]) + / 500 for x in range(500)] + self.curve_y[index] = scipy_stats.norm.pdf( + self.curve_x[index], loc=mean[index], scale=sd[index]) + + if self.histnorm == ALTERNATIVE_HISTNORM: + self.curve_y[index] *= self.bin_size[index] + + for index in range(self.trace_number): + curve[index] = dict(type='scatter', + x=self.curve_x[index], + y=self.curve_y[index], + xaxis='x1', + yaxis='y1', + mode='lines', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=False if self.show_hist else True, + marker=dict(color=self.colors[index])) + return curve + + def make_rug(self): + """ + Makes the rug plot(s) for create_distplot(). + + :rtype (list) rug: list of rug plot representations + """ + rug = [None] * self.trace_number + for index in range(self.trace_number): + + rug[index] = dict(type='scatter', + x=self.hist_data[index], + y=([self.group_labels[index]] * + len(self.hist_data[index])), + xaxis='x1', + yaxis='y2', + mode='markers', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=(False if self.show_hist or + self.show_curve else True), + text=self.rug_text[index], + marker=dict(color=self.colors[index], + symbol='line-ns-open')) + return rug diff --git a/plotly/figure_factory/_gantt.py b/plotly/figure_factory/_gantt.py new file mode 100644 index 00000000000..c78434f0381 --- /dev/null +++ b/plotly/figure_factory/_gantt.py @@ -0,0 +1,778 @@ +from __future__ import absolute_import + +from numbers import Number + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils + +pd = optional_imports.get_module('pandas') + +REQUIRED_GANTT_KEYS = ['Task', 'Start', 'Finish'] + + +def validate_gantt(df): + """ + Validates the inputted dataframe or list + """ + if pd and isinstance(df, pd.core.frame.DataFrame): + # validate that df has all the required keys + for key in REQUIRED_GANTT_KEYS: + if key not in df: + raise exceptions.PlotlyError( + "The columns in your dataframe must include the " + "following keys: {0}".format( + ', '.join(REQUIRED_GANTT_KEYS)) + ) + + num_of_rows = len(df.index) + chart = [] + for index in range(num_of_rows): + task_dict = {} + for key in df: + task_dict[key] = df.ix[index][key] + chart.append(task_dict) + + return chart + + # validate if df is a list + if not isinstance(df, list): + raise exceptions.PlotlyError("You must input either a dataframe " + "or a list of dictionaries.") + + # validate if df is empty + if len(df) <= 0: + raise exceptions.PlotlyError("Your list is empty. It must contain " + "at least one dictionary.") + if not isinstance(df[0], dict): + raise exceptions.PlotlyError("Your list must only " + "include dictionaries.") + return df + + +def gantt(chart, colors, title, bar_width, showgrid_x, showgrid_y, height, + width, tasks=None, task_names=None, data=None, group_tasks=False): + """ + Refer to create_gantt() for docstring + """ + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] + if data is None: + data = [] + + for index in range(len(chart)): + task = dict(x0=chart[index]['Start'], + x1=chart[index]['Finish'], + name=chart[index]['Task']) + if 'Description' in chart[index]: + task['description'] = chart[index]['Description'] + tasks.append(task) + + shape_template = { + 'type': 'rect', + 'xref': 'x', + 'yref': 'y', + 'opacity': 1, + 'line': { + 'width': 0, + } + } + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]['name'] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted first + # are shown at the top + if group_tasks: + task_names.reverse() + + color_index = 0 + for index in range(len(tasks)): + tn = tasks[index]['name'] + del tasks[index]['name'] + tasks[index].update(shape_template) + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]['y0'] = groupID - bar_width + tasks[index]['y1'] = groupID + bar_width + + # check if colors need to be looped + if color_index >= len(colors): + color_index = 0 + tasks[index]['fillcolor'] = colors[color_index] + # Add a line for hover text and autorange + entry = dict( + x=[tasks[index]['x0'], tasks[index]['x1']], + y=[groupID, groupID], + name='', + marker={'color': 'white'} + ) + if "description" in tasks[index]: + entry['text'] = tasks[index]['description'] + del tasks[index]['description'] + data.append(entry) + color_index += 1 + + layout = dict( + title=title, + showlegend=False, + height=height, + width=width, + shapes=[], + hovermode='closest', + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list([ + dict(count=7, + label='1w', + step='day', + stepmode='backward'), + dict(count=1, + label='1m', + step='month', + stepmode='backward'), + dict(count=6, + label='6m', + step='month', + stepmode='backward'), + dict(count=1, + label='YTD', + step='year', + stepmode='todate'), + dict(count=1, + label='1y', + step='year', + stepmode='backward'), + dict(step='all') + ]) + ), + type='date' + ) + ) + layout['shapes'] = tasks + + fig = dict(data=data, layout=layout) + return fig + + +def gantt_colorscale(chart, colors, title, index_col, show_colorbar, bar_width, + showgrid_x, showgrid_y, height, width, tasks=None, + task_names=None, data=None, group_tasks=False): + """ + Refer to FigureFactory.create_gantt() for docstring + """ + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] + if data is None: + data = [] + showlegend = False + + for index in range(len(chart)): + task = dict(x0=chart[index]['Start'], + x1=chart[index]['Finish'], + name=chart[index]['Task']) + if 'Description' in chart[index]: + task['description'] = chart[index]['Description'] + tasks.append(task) + + shape_template = { + 'type': 'rect', + 'xref': 'x', + 'yref': 'y', + 'opacity': 1, + 'line': { + 'width': 0, + } + } + + # compute the color for task based on indexing column + if isinstance(chart[0][index_col], Number): + # check that colors has at least 2 colors + if len(colors) < 2: + raise exceptions.PlotlyError( + "You must use at least 2 colors in 'colors' if you " + "are using a colorscale. However only the first two " + "colors given will be used for the lower and upper " + "bounds on the colormap." + ) + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]['name'] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted + # first are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]['name'] + del tasks[index]['name'] + tasks[index].update(shape_template) + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]['y0'] = groupID - bar_width + tasks[index]['y1'] = groupID + bar_width + + # unlabel color + colors = utils.color_parser(colors, utils.unlabel_rgb) + lowcolor = colors[0] + highcolor = colors[1] + + intermed = (chart[index][index_col]) / 100.0 + intermed_color = utils.find_intermediate_color( + lowcolor, highcolor, intermed + ) + intermed_color = utils.color_parser( + intermed_color, utils.label_rgb + ) + tasks[index]['fillcolor'] = intermed_color + # relabel colors with 'rgb' + colors = utils.color_parser(colors, utils.label_rgb) + + # add a line for hover text and autorange + entry = dict( + x=[tasks[index]['x0'], tasks[index]['x1']], + y=[groupID, groupID], + name='', + marker={'color': 'white'} + ) + if "description" in tasks[index]: + entry['text'] = tasks[index]['description'] + del tasks[index]['description'] + data.append(entry) + + if show_colorbar is True: + # generate dummy data for colorscale visibility + data.append( + dict( + x=[tasks[index]['x0'], tasks[index]['x0']], + y=[index, index], + name='', + marker={'color': 'white', + 'colorscale': [[0, colors[0]], [1, colors[1]]], + 'showscale': True, + 'cmax': 100, + 'cmin': 0} + ) + ) + + if isinstance(chart[0][index_col], str): + index_vals = [] + for row in range(len(tasks)): + if chart[row][index_col] not in index_vals: + index_vals.append(chart[row][index_col]) + + index_vals.sort() + + if len(colors) < len(index_vals): + raise exceptions.PlotlyError( + "Error. The number of colors in 'colors' must be no less " + "than the number of unique index values in your group " + "column." + ) + + # make a dictionary assignment to each index value + index_vals_dict = {} + # define color index + c_index = 0 + for key in index_vals: + if c_index > len(colors) - 1: + c_index = 0 + index_vals_dict[key] = colors[c_index] + c_index += 1 + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]['name'] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted + # first are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]['name'] + del tasks[index]['name'] + tasks[index].update(shape_template) + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]['y0'] = groupID - bar_width + tasks[index]['y1'] = groupID + bar_width + + tasks[index]['fillcolor'] = index_vals_dict[ + chart[index][index_col] + ] + + # add a line for hover text and autorange + entry = dict( + x=[tasks[index]['x0'], tasks[index]['x1']], + y=[groupID, groupID], + name='', + marker={'color': 'white'} + ) + if "description" in tasks[index]: + entry['text'] = tasks[index]['description'] + del tasks[index]['description'] + data.append(entry) + + if show_colorbar is True: + # generate dummy data to generate legend + showlegend = True + for k, index_value in enumerate(index_vals): + data.append( + dict( + x=[tasks[index]['x0'], tasks[index]['x0']], + y=[k, k], + showlegend=True, + name=str(index_value), + hoverinfo='none', + marker=dict( + color=colors[k], + size=1 + ) + ) + ) + + layout = dict( + title=title, + showlegend=showlegend, + height=height, + width=width, + shapes=[], + hovermode='closest', + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list([ + dict(count=7, + label='1w', + step='day', + stepmode='backward'), + dict(count=1, + label='1m', + step='month', + stepmode='backward'), + dict(count=6, + label='6m', + step='month', + stepmode='backward'), + dict(count=1, + label='YTD', + step='year', + stepmode='todate'), + dict(count=1, + label='1y', + step='year', + stepmode='backward'), + dict(step='all') + ]) + ), + type='date' + ) + ) + layout['shapes'] = tasks + + fig = dict(data=data, layout=layout) + return fig + + +def gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width, + showgrid_x, showgrid_y, height, width, tasks=None, + task_names=None, data=None, group_tasks=False): + """ + Refer to FigureFactory.create_gantt() for docstring + """ + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] + if data is None: + data = [] + showlegend = False + + for index in range(len(chart)): + task = dict(x0=chart[index]['Start'], + x1=chart[index]['Finish'], + name=chart[index]['Task']) + if 'Description' in chart[index]: + task['description'] = chart[index]['Description'] + tasks.append(task) + + shape_template = { + 'type': 'rect', + 'xref': 'x', + 'yref': 'y', + 'opacity': 1, + 'line': { + 'width': 0, + } + } + + index_vals = [] + for row in range(len(tasks)): + if chart[row][index_col] not in index_vals: + index_vals.append(chart[row][index_col]) + + index_vals.sort() + + # verify each value in index column appears in colors dictionary + for key in index_vals: + if key not in colors: + raise exceptions.PlotlyError( + "If you are using colors as a dictionary, all of its " + "keys must be all the values in the index column." + ) + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]['name'] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted first + # are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]['name'] + del tasks[index]['name'] + tasks[index].update(shape_template) + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]['y0'] = groupID - bar_width + tasks[index]['y1'] = groupID + bar_width + + tasks[index]['fillcolor'] = colors[chart[index][index_col]] + + # add a line for hover text and autorange + entry = dict( + x=[tasks[index]['x0'], tasks[index]['x1']], + y=[groupID, groupID], + name='', + marker={'color': 'white'} + ) + if "description" in tasks[index]: + entry['text'] = tasks[index]['description'] + del tasks[index]['description'] + data.append(entry) + + if show_colorbar is True: + # generate dummy data to generate legend + showlegend = True + for k, index_value in enumerate(index_vals): + data.append( + dict( + x=[tasks[index]['x0'], tasks[index]['x0']], + y=[k, k], + showlegend=True, + hoverinfo='none', + name=str(index_value), + marker=dict( + color=colors[index_value], + size=1 + ) + ) + ) + + layout = dict( + title=title, + showlegend=showlegend, + height=height, + width=width, + shapes=[], + hovermode='closest', + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list([ + dict(count=7, + label='1w', + step='day', + stepmode='backward'), + dict(count=1, + label='1m', + step='month', + stepmode='backward'), + dict(count=6, + label='6m', + step='month', + stepmode='backward'), + dict(count=1, + label='YTD', + step='year', + stepmode='todate'), + dict(count=1, + label='1y', + step='year', + stepmode='backward'), + dict(step='all') + ]) + ), + type='date' + ) + ) + layout['shapes'] = tasks + + fig = dict(data=data, layout=layout) + return fig + + +def create_gantt(df, colors=None, index_col=None, show_colorbar=False, + reverse_colors=False, title='Gantt Chart', bar_width=0.2, + showgrid_x=False, showgrid_y=False, height=600, width=900, + tasks=None, task_names=None, data=None, group_tasks=False): + """ + Returns figure for a gantt chart + + :param (array|list) df: input data for gantt chart. Must be either a + a dataframe or a list. If dataframe, the columns must include + 'Task', 'Start' and 'Finish'. Other columns can be included and + used for indexing. If a list, its elements must be dictionaries + with the same required column headers: 'Task', 'Start' and + 'Finish'. + :param (str|list|dict|tuple) colors: either a plotly scale name, an + rgb or hex color, a color tuple or a list of colors. An rgb color + is of the form 'rgb(x, y, z)' where x, y, z belong to the interval + [0, 255] and a color tuple is a tuple of the form (a, b, c) where + a, b and c belong to [0, 1]. If colors is a list, it must + contain the valid color types aforementioned as its members. + If a dictionary, all values of the indexing column must be keys in + colors. + :param (str|float) index_col: the column header (if df is a data + frame) that will function as the indexing column. If df is a list, + index_col must be one of the keys in all the items of df. + :param (bool) show_colorbar: determines if colorbar will be visible. + Only applies if values in the index column are numeric. + :param (bool) reverse_colors: reverses the order of selected colors + :param (str) title: the title of the chart + :param (float) bar_width: the width of the horizontal bars in the plot + :param (bool) showgrid_x: show/hide the x-axis grid + :param (bool) showgrid_y: show/hide the y-axis grid + :param (float) height: the height of the chart + :param (float) width: the width of the chart + + Example 1: Simple Gantt Chart + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + # Make data for chart + df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'), + dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'), + dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')] + + # Create a figure + fig = create_gantt(df) + + # Plot the data + py.iplot(fig, filename='Simple Gantt Chart', world_readable=True) + ``` + + Example 2: Index by Column with Numerical Entries + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + # Make data for chart + df = [dict(Task="Job A", Start='2009-01-01', + Finish='2009-02-30', Complete=10), + dict(Task="Job B", Start='2009-03-05', + Finish='2009-04-15', Complete=60), + dict(Task="Job C", Start='2009-02-20', + Finish='2009-05-30', Complete=95)] + + # Create a figure with Plotly colorscale + fig = create_gantt(df, colors='Blues', index_col='Complete', + show_colorbar=True, bar_width=0.5, + showgrid_x=True, showgrid_y=True) + + # Plot the data + py.iplot(fig, filename='Numerical Entries', world_readable=True) + ``` + + Example 3: Index by Column with String Entries + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + # Make data for chart + df = [dict(Task="Job A", Start='2009-01-01', + Finish='2009-02-30', Resource='Apple'), + dict(Task="Job B", Start='2009-03-05', + Finish='2009-04-15', Resource='Grape'), + dict(Task="Job C", Start='2009-02-20', + Finish='2009-05-30', Resource='Banana')] + + # Create a figure with Plotly colorscale + fig = create_gantt(df, colors=['rgb(200, 50, 25)', (1, 0, 1), '#6c4774'], + index_col='Resource', reverse_colors=True, + show_colorbar=True) + + # Plot the data + py.iplot(fig, filename='String Entries', world_readable=True) + ``` + + Example 4: Use a dictionary for colors + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + # Make data for chart + df = [dict(Task="Job A", Start='2009-01-01', + Finish='2009-02-30', Resource='Apple'), + dict(Task="Job B", Start='2009-03-05', + Finish='2009-04-15', Resource='Grape'), + dict(Task="Job C", Start='2009-02-20', + Finish='2009-05-30', Resource='Banana')] + + # Make a dictionary of colors + colors = {'Apple': 'rgb(255, 0, 0)', + 'Grape': 'rgb(170, 14, 200)', + 'Banana': (1, 1, 0.2)} + + # Create a figure with Plotly colorscale + fig = create_gantt(df, colors=colors, index_col='Resource', + show_colorbar=True) + + # Plot the data + py.iplot(fig, filename='dictioanry colors', world_readable=True) + ``` + + Example 5: Use a pandas dataframe + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + import pandas as pd + + # Make data as a dataframe + df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10], + ['Fast', '2011-01-01', '2012-06-05', 55], + ['Eat', '2012-01-05', '2013-07-05', 94]], + columns=['Task', 'Start', 'Finish', 'Complete']) + + # Create a figure with Plotly colorscale + fig = create_gantt(df, colors='Blues', index_col='Complete', + show_colorbar=True, bar_width=0.5, + showgrid_x=True, showgrid_y=True) + + # Plot the data + py.iplot(fig, filename='data with dataframe', world_readable=True) + ``` + """ + # validate gantt input data + chart = validate_gantt(df) + + if index_col: + if index_col not in chart[0]: + raise exceptions.PlotlyError( + "In order to use an indexing column and assign colors to " + "the values of the index, you must choose an actual " + "column name in the dataframe or key if a list of " + "dictionaries is being used.") + + # validate gantt index column + index_list = [] + for dictionary in chart: + index_list.append(dictionary[index_col]) + utils.validate_index(index_list) + + # Validate colors + if isinstance(colors, dict): + colors = utils.validate_colors_dict(colors, 'rgb') + else: + colors = utils.validate_colors(colors, 'rgb') + + if reverse_colors is True: + colors.reverse() + + if not index_col: + if isinstance(colors, dict): + raise exceptions.PlotlyError( + "Error. You have set colors to a dictionary but have not " + "picked an index. An index is required if you are " + "assigning colors to particular values in a dictioanry." + ) + fig = gantt( + chart, colors, title, bar_width, showgrid_x, showgrid_y, + height, width, tasks=None, task_names=None, data=None, + group_tasks=group_tasks + ) + return fig + else: + if not isinstance(colors, dict): + fig = gantt_colorscale( + chart, colors, title, index_col, show_colorbar, bar_width, + showgrid_x, showgrid_y, height, width, + tasks=None, task_names=None, data=None, group_tasks=group_tasks + ) + return fig + else: + fig = gantt_dict( + chart, colors, title, index_col, show_colorbar, bar_width, + showgrid_x, showgrid_y, height, width, + tasks=None, task_names=None, data=None, group_tasks=group_tasks + ) + return fig diff --git a/plotly/figure_factory/_ohlc.py b/plotly/figure_factory/_ohlc.py new file mode 100644 index 00000000000..b5e84cd6d93 --- /dev/null +++ b/plotly/figure_factory/_ohlc.py @@ -0,0 +1,380 @@ +from __future__ import absolute_import + +from plotly import exceptions +from plotly.graph_objs import graph_objs +from plotly.figure_factory import utils + + +# Default colours for finance charts +_DEFAULT_INCREASING_COLOR = '#3D9970' # http://clrs.cc +_DEFAULT_DECREASING_COLOR = '#FF4136' + + +def validate_ohlc(open, high, low, close, direction, **kwargs): + """ + ohlc and candlestick specific validations + + Specifically, this checks that the high value is the greatest value and + the low value is the lowest value in each unit. + + See FigureFactory.create_ohlc() or FigureFactory.create_candlestick() + for params + + :raises: (PlotlyError) If the high value is not the greatest value in + each unit. + :raises: (PlotlyError) If the low value is not the lowest value in each + unit. + :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing' + """ + for lst in [open, low, close]: + for index in range(len(high)): + if high[index] < lst[index]: + raise exceptions.PlotlyError("Oops! Looks like some of " + "your high values are less " + "the corresponding open, " + "low, or close values. " + "Double check that your data " + "is entered in O-H-L-C order") + + for lst in [open, high, close]: + for index in range(len(low)): + if low[index] > lst[index]: + raise exceptions.PlotlyError("Oops! Looks like some of " + "your low values are greater " + "than the corresponding high" + ", open, or close values. " + "Double check that your data " + "is entered in O-H-L-C order") + + direction_opts = ('increasing', 'decreasing', 'both') + if direction not in direction_opts: + raise exceptions.PlotlyError("direction must be defined as " + "'increasing', 'decreasing', or " + "'both'") + + +def make_increasing_ohlc(open, high, low, close, dates, **kwargs): + """ + Makes increasing ohlc sticks + + _make_increasing_ohlc() and _make_decreasing_ohlc separate the + increasing trace from the decreasing trace so kwargs (such as + color) can be passed separately to increasing or decreasing traces + when direction is set to 'increasing' or 'decreasing' in + FigureFactory.create_candlestick() + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. + + :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc + sticks. + """ + (flat_increase_x, + flat_increase_y, + text_increase) = _OHLC(open, high, low, close, dates).get_increase() + + if 'name' in kwargs: + showlegend = True + else: + kwargs.setdefault('name', 'Increasing') + showlegend = False + + kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR, + width=1)) + kwargs.setdefault('text', text_increase) + + ohlc_incr = dict(type='scatter', + x=flat_increase_x, + y=flat_increase_y, + mode='lines', + showlegend=showlegend, + **kwargs) + return ohlc_incr + + +def make_decreasing_ohlc(open, high, low, close, dates, **kwargs): + """ + Makes decreasing ohlc sticks + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. + + :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc + sticks. + """ + (flat_decrease_x, + flat_decrease_y, + text_decrease) = _OHLC(open, high, low, close, dates).get_decrease() + + kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR, + width=1)) + kwargs.setdefault('text', text_decrease) + kwargs.setdefault('showlegend', False) + kwargs.setdefault('name', 'Decreasing') + + ohlc_decr = dict(type='scatter', + x=flat_decrease_x, + y=flat_decrease_y, + mode='lines', + **kwargs) + return ohlc_decr + + +def create_ohlc(open, high, low, close, dates=None, direction='both', + **kwargs): + """ + BETA function that creates an ohlc chart + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing + :param (list) dates: list of datetime objects. Default: None + :param (string) direction: direction can be 'increasing', 'decreasing', + or 'both'. When the direction is 'increasing', the returned figure + consists of all units where the close value is greater than the + corresponding open value, and when the direction is 'decreasing', + the returned figure consists of all units where the close value is + less than or equal to the corresponding open value. When the + direction is 'both', both increasing and decreasing units are + returned. Default: 'both' + :param kwargs: kwargs passed through plotly.graph_objs.Scatter. + These kwargs describe other attributes about the ohlc Scatter trace + such as the color or the legend name. For more information on valid + kwargs call help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of an ohlc chart figure. + + Example 1: Simple OHLC chart from a Pandas DataFrame + ``` + import plotly.plotly as py + from plotly.figure_factory import create_ohlc + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), + datetime(2008, 10, 15)) + fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index) + + py.plot(fig, filename='finance/aapl-ohlc') + ``` + + Example 2: Add text and annotations to the OHLC chart + ``` + import plotly.plotly as py + from plotly.figure_factory import create_ohlc + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), + datetime(2008, 10, 15)) + fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index) + + # Update the fig - options here: https://plot.ly/python/reference/#Layout + fig['layout'].update({ + 'title': 'The Great Recession', + 'yaxis': {'title': 'AAPL Stock'}, + 'shapes': [{ + 'x0': '2008-09-15', 'x1': '2008-09-15', 'type': 'line', + 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper', + 'line': {'color': 'rgb(40,40,40)', 'width': 0.5} + }], + 'annotations': [{ + 'text': "the fall of Lehman Brothers", + 'x': '2008-09-15', 'y': 1.02, + 'xref': 'x', 'yref': 'paper', + 'showarrow': False, 'xanchor': 'left' + }] + }) + + py.plot(fig, filename='finance/aapl-recession-ohlc', validate=False) + ``` + + Example 3: Customize the OHLC colors + ``` + import plotly.plotly as py + from plotly.figure_factory import create_ohlc + from plotly.graph_objs import Line, Marker + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), + datetime(2009, 4, 1)) + + # Make increasing ohlc sticks and customize their color and name + fig_increasing = create_ohlc(df.Open, df.High, df.Low, df.Close, + dates=df.index, direction='increasing', + name='AAPL', + line=Line(color='rgb(150, 200, 250)')) + + # Make decreasing ohlc sticks and customize their color and name + fig_decreasing = create_ohlc(df.Open, df.High, df.Low, df.Close, + dates=df.index, direction='decreasing', + line=Line(color='rgb(128, 128, 128)')) + + # Initialize the figure + fig = fig_increasing + + # Add decreasing data with .extend() + fig['data'].extend(fig_decreasing['data']) + + py.iplot(fig, filename='finance/aapl-ohlc-colors', validate=False) + ``` + + Example 4: OHLC chart with datetime objects + ``` + import plotly.plotly as py + from plotly.figure_factory import create_ohlc + + from datetime import datetime + + # Add data + open_data = [33.0, 33.3, 33.5, 33.0, 34.1] + high_data = [33.1, 33.3, 33.6, 33.2, 34.8] + low_data = [32.7, 32.7, 32.8, 32.6, 32.8] + close_data = [33.0, 32.9, 33.3, 33.1, 33.1] + dates = [datetime(year=2013, month=10, day=10), + datetime(year=2013, month=11, day=10), + datetime(year=2013, month=12, day=10), + datetime(year=2014, month=1, day=10), + datetime(year=2014, month=2, day=10)] + + # Create ohlc + fig = create_ohlc(open_data, high_data, low_data, close_data, dates=dates) + + py.iplot(fig, filename='finance/simple-ohlc', validate=False) + ``` + """ + if dates is not None: + utils.validate_equal_length(open, high, low, close, dates) + else: + utils.validate_equal_length(open, high, low, close) + validate_ohlc(open, high, low, close, direction, **kwargs) + + if direction is 'increasing': + ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, + **kwargs) + data = [ohlc_incr] + elif direction is 'decreasing': + ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, + **kwargs) + data = [ohlc_decr] + else: + ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, + **kwargs) + ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, + **kwargs) + data = [ohlc_incr, ohlc_decr] + + layout = graph_objs.Layout(xaxis=dict(zeroline=False), + hovermode='closest') + + return graph_objs.Figure(data=data, layout=layout) + + +class _OHLC(object): + """ + Refer to FigureFactory.create_ohlc_increase() for docstring. + """ + def __init__(self, open, high, low, close, dates, **kwargs): + self.open = open + self.high = high + self.low = low + self.close = close + self.empty = [None] * len(open) + self.dates = dates + + self.all_x = [] + self.all_y = [] + self.increase_x = [] + self.increase_y = [] + self.decrease_x = [] + self.decrease_y = [] + self.get_all_xy() + self.separate_increase_decrease() + + def get_all_xy(self): + """ + Zip data to create OHLC shape + + OHLC shape: low to high vertical bar with + horizontal branches for open and close values. + If dates were added, the smallest date difference is calculated and + multiplied by .2 to get the length of the open and close branches. + If no date data was provided, the x-axis is a list of integers and the + length of the open and close branches is .2. + """ + self.all_y = list(zip(self.open, self.open, self.high, + self.low, self.close, self.close, self.empty)) + if self.dates is not None: + date_dif = [] + for i in range(len(self.dates) - 1): + date_dif.append(self.dates[i + 1] - self.dates[i]) + date_dif_min = (min(date_dif)) / 5 + self.all_x = [[x - date_dif_min, x, x, x, x, x + + date_dif_min, None] for x in self.dates] + else: + self.all_x = [[x - .2, x, x, x, x, x + .2, None] + for x in range(len(self.open))] + + def separate_increase_decrease(self): + """ + Separate data into two groups: increase and decrease + + (1) Increase, where close > open and + (2) Decrease, where close <= open + """ + for index in range(len(self.open)): + if self.close[index] is None: + pass + elif self.close[index] > self.open[index]: + self.increase_x.append(self.all_x[index]) + self.increase_y.append(self.all_y[index]) + else: + self.decrease_x.append(self.all_x[index]) + self.decrease_y.append(self.all_y[index]) + + def get_increase(self): + """ + Flatten increase data and get increase text + + :rtype (list, list, list): flat_increase_x: x-values for the increasing + trace, flat_increase_y: y=values for the increasing trace and + text_increase: hovertext for the increasing trace + """ + flat_increase_x = utils.flatten(self.increase_x) + flat_increase_y = utils.flatten(self.increase_y) + text_increase = (("Open", "Open", "High", + "Low", "Close", "Close", '') + * (len(self.increase_x))) + + return flat_increase_x, flat_increase_y, text_increase + + def get_decrease(self): + """ + Flatten decrease data and get decrease text + + :rtype (list, list, list): flat_decrease_x: x-values for the decreasing + trace, flat_decrease_y: y=values for the decreasing trace and + text_decrease: hovertext for the decreasing trace + """ + flat_decrease_x = utils.flatten(self.decrease_x) + flat_decrease_y = utils.flatten(self.decrease_y) + text_decrease = (("Open", "Open", "High", + "Low", "Close", "Close", '') + * (len(self.decrease_x))) + + return flat_decrease_x, flat_decrease_y, text_decrease diff --git a/plotly/figure_factory/_quiver.py b/plotly/figure_factory/_quiver.py new file mode 100644 index 00000000000..8d0de352baf --- /dev/null +++ b/plotly/figure_factory/_quiver.py @@ -0,0 +1,243 @@ +from __future__ import absolute_import + +import math + +from plotly import exceptions +from plotly.graph_objs import graph_objs +from plotly.figure_factory import utils + + +def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3, + angle=math.pi / 9, **kwargs): + """ + Returns data for a quiver plot. + + :param (list|ndarray) x: x coordinates of the arrow locations + :param (list|ndarray) y: y coordinates of the arrow locations + :param (list|ndarray) u: x components of the arrow vectors + :param (list|ndarray) v: y components of the arrow vectors + :param (float in [0,1]) scale: scales size of the arrows(ideally to + avoid overlap). Default = .1 + :param (float in [0,1]) arrow_scale: value multiplied to length of barb + to get length of arrowhead. Default = .3 + :param (angle in radians) angle: angle of arrowhead. Default = pi/9 + :param kwargs: kwargs passed through plotly.graph_objs.Scatter + for more information on valid kwargs call + help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of quiver figure. + + Example 1: Trivial Quiver + ``` + import plotly.plotly as py + from plotly.figure_factory import create_quiver + + import math + + # 1 Arrow from (0,0) to (1,1) + fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1) + + py.plot(fig, filename='quiver') + ``` + + Example 2: Quiver plot using meshgrid + ``` + import plotly.plotly as py + from plotly.figure_factory import create_quiver + + import numpy as np + import math + + # Add data + x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2)) + u = np.cos(x)*y + v = np.sin(x)*y + + #Create quiver + fig = create_quiver(x, y, u, v) + + # Plot + py.plot(fig, filename='quiver') + ``` + + Example 3: Styling the quiver plot + ``` + import plotly.plotly as py + from plotly.figure_factory import create_quiver + import numpy as np + import math + + # Add data + x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5), + np.arange(-math.pi, math.pi, .5)) + u = np.cos(x)*y + v = np.sin(x)*y + + # Create quiver + fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6, + name='Wind Velocity', line=Line(width=1)) + + # Add title to layout + fig['layout'].update(title='Quiver Plot') + + # Plot + py.plot(fig, filename='quiver') + ``` + """ + utils.validate_equal_length(x, y, u, v) + utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale) + + barb_x, barb_y = _Quiver(x, y, u, v, scale, + arrow_scale, angle).get_barbs() + arrow_x, arrow_y = _Quiver(x, y, u, v, scale, + arrow_scale, angle).get_quiver_arrows() + quiver = graph_objs.Scatter(x=barb_x + arrow_x, + y=barb_y + arrow_y, + mode='lines', **kwargs) + + data = [quiver] + layout = graph_objs.Layout(hovermode='closest') + + return graph_objs.Figure(data=data, layout=layout) + + +class _Quiver(object): + """ + Refer to FigureFactory.create_quiver() for docstring + """ + def __init__(self, x, y, u, v, + scale, arrow_scale, angle, **kwargs): + try: + x = utils.flatten(x) + except exceptions.PlotlyError: + pass + + try: + y = utils.flatten(y) + except exceptions.PlotlyError: + pass + + try: + u = utils.flatten(u) + except exceptions.PlotlyError: + pass + + try: + v = utils.flatten(v) + except exceptions.PlotlyError: + pass + + self.x = x + self.y = y + self.u = u + self.v = v + self.scale = scale + self.arrow_scale = arrow_scale + self.angle = angle + self.end_x = [] + self.end_y = [] + self.scale_uv() + barb_x, barb_y = self.get_barbs() + arrow_x, arrow_y = self.get_quiver_arrows() + + def scale_uv(self): + """ + Scales u and v to avoid overlap of the arrows. + + u and v are added to x and y to get the + endpoints of the arrows so a smaller scale value will + result in less overlap of arrows. + """ + self.u = [i * self.scale for i in self.u] + self.v = [i * self.scale for i in self.v] + + def get_barbs(self): + """ + Creates x and y startpoint and endpoint pairs + + After finding the endpoint of each barb this zips startpoint and + endpoint pairs to create 2 lists: x_values for barbs and y values + for barbs + + :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint + x_value pairs separated by a None to create the barb of the arrow, + and list of startpoint and endpoint y_value pairs separated by a + None to create the barb of the arrow. + """ + self.end_x = [i + j for i, j in zip(self.x, self.u)] + self.end_y = [i + j for i, j in zip(self.y, self.v)] + empty = [None] * len(self.x) + barb_x = utils.flatten(zip(self.x, self.end_x, empty)) + barb_y = utils.flatten(zip(self.y, self.end_y, empty)) + return barb_x, barb_y + + def get_quiver_arrows(self): + """ + Creates lists of x and y values to plot the arrows + + Gets length of each barb then calculates the length of each side of + the arrow. Gets angle of barb and applies angle to each side of the + arrowhead. Next uses arrow_scale to scale the length of arrowhead and + creates x and y values for arrowhead point1 and point2. Finally x and y + values for point1, endpoint and point2s for each arrowhead are + separated by a None and zipped to create lists of x and y values for + the arrows. + + :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2 + x_values separated by a None to create the arrowhead and list of + point1, endpoint, point2 y_values separated by a None to create + the barb of the arrow. + """ + dif_x = [i - j for i, j in zip(self.end_x, self.x)] + dif_y = [i - j for i, j in zip(self.end_y, self.y)] + + # Get barb lengths(default arrow length = 30% barb length) + barb_len = [None] * len(self.x) + for index in range(len(barb_len)): + barb_len[index] = math.hypot(dif_x[index], dif_y[index]) + + # Make arrow lengths + arrow_len = [None] * len(self.x) + arrow_len = [i * self.arrow_scale for i in barb_len] + + # Get barb angles + barb_ang = [None] * len(self.x) + for index in range(len(barb_ang)): + barb_ang[index] = math.atan2(dif_y[index], dif_x[index]) + + # Set angles to create arrow + ang1 = [i + self.angle for i in barb_ang] + ang2 = [i - self.angle for i in barb_ang] + + cos_ang1 = [None] * len(ang1) + for index in range(len(ang1)): + cos_ang1[index] = math.cos(ang1[index]) + seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)] + + sin_ang1 = [None] * len(ang1) + for index in range(len(ang1)): + sin_ang1[index] = math.sin(ang1[index]) + seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)] + + cos_ang2 = [None] * len(ang2) + for index in range(len(ang2)): + cos_ang2[index] = math.cos(ang2[index]) + seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)] + + sin_ang2 = [None] * len(ang2) + for index in range(len(ang2)): + sin_ang2[index] = math.sin(ang2[index]) + seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)] + + # Set coordinates to create arrow + for index in range(len(self.end_x)): + point1_x = [i - j for i, j in zip(self.end_x, seg1_x)] + point1_y = [i - j for i, j in zip(self.end_y, seg1_y)] + point2_x = [i - j for i, j in zip(self.end_x, seg2_x)] + point2_y = [i - j for i, j in zip(self.end_y, seg2_y)] + + # Combine lists to create arrow + empty = [None] * len(self.end_x) + arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty)) + arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty)) + return arrow_x, arrow_y diff --git a/plotly/figure_factory/_scatterplot.py b/plotly/figure_factory/_scatterplot.py new file mode 100644 index 00000000000..e9926b26f86 --- /dev/null +++ b/plotly/figure_factory/_scatterplot.py @@ -0,0 +1,1136 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs +from plotly.tools import make_subplots + +pd = optional_imports.get_module('pandas') + +DIAG_CHOICES = ['scatter', 'histogram', 'box'] +VALID_COLORMAP_TYPES = ['cat', 'seq'] + + +def endpts_to_intervals(endpts): + """ + Returns a list of intervals for categorical colormaps + + Accepts a list or tuple of sequentially increasing numbers and returns + a list representation of the mathematical intervals with these numbers + as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]] + + :raises: (PlotlyError) If input is not a list or tuple + :raises: (PlotlyError) If the input contains a string + :raises: (PlotlyError) If any number does not increase after the + previous one in the sequence + """ + length = len(endpts) + # Check if endpts is a list or tuple + if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))): + raise exceptions.PlotlyError("The intervals_endpts argument must " + "be a list or tuple of a sequence " + "of increasing numbers.") + # Check if endpts contains only numbers + for item in endpts: + if isinstance(item, str): + raise exceptions.PlotlyError("The intervals_endpts argument " + "must be a list or tuple of a " + "sequence of increasing " + "numbers.") + # Check if numbers in endpts are increasing + for k in range(length - 1): + if endpts[k] >= endpts[k + 1]: + raise exceptions.PlotlyError("The intervals_endpts argument " + "must be a list or tuple of a " + "sequence of increasing " + "numbers.") + else: + intervals = [] + # add -inf to intervals + intervals.append([float('-inf'), endpts[0]]) + for k in range(length - 1): + interval = [] + interval.append(endpts[k]) + interval.append(endpts[k + 1]) + intervals.append(interval) + # add +inf to intervals + intervals.append([endpts[length - 1], float('inf')]) + return intervals + + +def hide_tick_labels_from_box_subplots(fig): + """ + Hides tick labels for box plots in scatterplotmatrix subplots. + """ + boxplot_xaxes = [] + for trace in fig['data']: + if trace['type'] == 'box': + # stores the xaxes which correspond to boxplot subplots + # since we use xaxis1, xaxis2, etc, in plotly.py + boxplot_xaxes.append( + 'xaxis{}'.format(trace['xaxis'][1:]) + ) + for xaxis in boxplot_xaxes: + fig['layout'][xaxis]['showticklabels'] = False + + +def validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs): + """ + Validates basic inputs for FigureFactory.create_scatterplotmatrix() + + :raises: (PlotlyError) If pandas is not imported + :raises: (PlotlyError) If pandas dataframe is not inputted + :raises: (PlotlyError) If pandas dataframe has <= 1 columns + :raises: (PlotlyError) If diagonal plot choice (diag) is not one of + the viable options + :raises: (PlotlyError) If colormap_type is not a valid choice + :raises: (PlotlyError) If kwargs contains 'size', 'color' or + 'colorscale' + """ + if not pd: + raise ImportError("FigureFactory.scatterplotmatrix requires " + "a pandas DataFrame.") + + # Check if pandas dataframe + if not isinstance(df, pd.core.frame.DataFrame): + raise exceptions.PlotlyError("Dataframe not inputed. Please " + "use a pandas dataframe to pro" + "duce a scatterplot matrix.") + + # Check if dataframe is 1 column or less + if len(df.columns) <= 1: + raise exceptions.PlotlyError("Dataframe has only one column. To " + "use the scatterplot matrix, use at " + "least 2 columns.") + + # Check that diag parameter is a valid selection + if diag not in DIAG_CHOICES: + raise exceptions.PlotlyError("Make sure diag is set to " + "one of {}".format(DIAG_CHOICES)) + + # Check that colormap_types is a valid selection + if colormap_type not in VALID_COLORMAP_TYPES: + raise exceptions.PlotlyError("Must choose a valid colormap type. " + "Either 'cat' or 'seq' for a cate" + "gorical and sequential colormap " + "respectively.") + + # Check for not 'size' or 'color' in 'marker' of **kwargs + if 'marker' in kwargs: + FORBIDDEN_PARAMS = ['size', 'color', 'colorscale'] + if any(param in kwargs['marker'] for param in FORBIDDEN_PARAMS): + raise exceptions.PlotlyError("Your kwargs dictionary cannot " + "include the 'size', 'color' or " + "'colorscale' key words inside " + "the marker dict since 'size' is " + "already an argument of the " + "scatterplot matrix function and " + "both 'color' and 'colorscale " + "are set internally.") + + +def scatterplot(dataframe, headers, diag, size, height, width, title, + **kwargs): + """ + Refer to FigureFactory.create_scatterplotmatrix() for docstring + + Returns fig for scatterplotmatrix without index + + """ + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + # Insert traces into trace_list + for listy in dataframe: + for listx in dataframe: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=listx, + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=listx, + name=None, + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + showlegend=False, + **kwargs + ) + trace_list.append(trace) + else: + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + marker=dict( + size=size), + showlegend=False, + **kwargs + ) + trace_list.append(trace) + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + fig.append_trace(trace_list[trace_index], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True + ) + + hide_tick_labels_from_box_subplots(fig) + + return fig + + +def scatterplot_dict(dataframe, headers, diag, size, + height, width, title, index, index_vals, + endpts, colormap, colormap_type, **kwargs): + """ + Refer to FigureFactory.create_scatterplotmatrix() for docstring + + Returns fig for scatterplotmatrix with both index and colormap picked. + Used if colormap is a dictionary with index values as keys pointing to + colors. Forces colormap_type to behave categorically because it would + not make sense colors are assigned to each index value and thus + implies that a categorical approach should be taken + + """ + + theme = colormap + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Work over all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + # create a dictionary for index_vals + unique_index_vals = {} + for name in index_vals: + if name not in unique_index_vals: + unique_index_vals[name] = [] + + # Fill all the rest of the names into the dictionary + for name in sorted(unique_index_vals.keys()): + new_listx = [] + new_listy = [] + for j in range(len(index_vals)): + if index_vals[j] == name: + new_listx.append(listx[j]) + new_listy.append(listy[j]) + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[name]), + showlegend=True + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[name]), + showlegend=True + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = theme[name] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + showlegend=True, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + marker=dict( + size=size, + color=theme[name]), + showlegend=True, + **kwargs + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[name]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[name]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = theme[name] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + marker=dict( + size=size, + color=theme[name]), + showlegend=False, + **kwargs + ) + # Push the trace into dictionary + unique_index_vals[name] = trace + trace_list.append(unique_index_vals) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for name in sorted(trace_list[trace_index].keys()): + fig.append_trace( + trace_list[trace_index][name], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == 'histogram': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True, + barmode='stack') + return fig + + else: + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + +def scatterplot_theme(dataframe, headers, diag, size, height, width, title, + index, index_vals, endpts, colormap, colormap_type, + **kwargs): + """ + Refer to FigureFactory.create_scatterplotmatrix() for docstring + + Returns fig for scatterplotmatrix with both index and colormap picked + + """ + + # Check if index is made of string values + if isinstance(index_vals[0], str): + unique_index_vals = [] + for name in index_vals: + if name not in unique_index_vals: + unique_index_vals.append(name) + n_colors_len = len(unique_index_vals) + + # Convert colormap to list of n RGB tuples + if colormap_type == 'seq': + foo = utils.color_parser(colormap, utils.unlabel_rgb) + foo = utils.n_colors(foo[0], foo[1], n_colors_len) + theme = utils.color_parser(foo, utils.label_rgb) + + if colormap_type == 'cat': + # leave list of colors the same way + theme = colormap + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Work over all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + # create a dictionary for index_vals + unique_index_vals = {} + for name in index_vals: + if name not in unique_index_vals: + unique_index_vals[name] = [] + + c_indx = 0 # color index + # Fill all the rest of the names into the dictionary + for name in sorted(unique_index_vals.keys()): + new_listx = [] + new_listy = [] + for j in range(len(index_vals)): + if index_vals[j] == name: + new_listx.append(listx[j]) + new_listy.append(listy[j]) + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[c_indx]), + showlegend=True + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[c_indx]), + showlegend=True + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + showlegend=True, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + marker=dict( + size=size, + color=theme[c_indx]), + showlegend=True, + **kwargs + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[c_indx]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[c_indx]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + marker=dict( + size=size, + color=theme[c_indx]), + showlegend=False, + **kwargs + ) + # Push the trace into dictionary + unique_index_vals[name] = trace + if c_indx >= (len(theme) - 1): + c_indx = -1 + c_indx += 1 + trace_list.append(unique_index_vals) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for name in sorted(trace_list[trace_index].keys()): + fig.append_trace( + trace_list[trace_index][name], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == 'histogram': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True, + barmode='stack') + return fig + + elif diag == 'box': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + if endpts: + intervals = endpts_to_intervals(endpts) + + # Convert colormap to list of n RGB tuples + if colormap_type == 'seq': + foo = utils.color_parser(colormap, utils.unlabel_rgb) + foo = utils.n_colors(foo[0], foo[1], len(intervals)) + theme = utils.color_parser(foo, utils.label_rgb) + + if colormap_type == 'cat': + # leave list of colors the same way + theme = colormap + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Work over all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + interval_labels = {} + for interval in intervals: + interval_labels[str(interval)] = [] + + c_indx = 0 # color index + # Fill all the rest of the names into the dictionary + for interval in intervals: + new_listx = [] + new_listy = [] + for j in range(len(index_vals)): + if interval[0] < index_vals[j] <= interval[1]: + new_listx.append(listx[j]) + new_listy.append(listy[j]) + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[c_indx]), + showlegend=True + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[c_indx]), + showlegend=True + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + (kwargs['marker'] + ['color']) = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=str(interval), + showlegend=True, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=str(interval), + marker=dict( + size=size, + color=theme[c_indx]), + showlegend=True, + **kwargs + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[c_indx]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[c_indx]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + (kwargs['marker'] + ['color']) = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=str(interval), + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=str(interval), + marker=dict( + size=size, + color=theme[c_indx]), + showlegend=False, + **kwargs + ) + # Push the trace into dictionary + interval_labels[str(interval)] = trace + if c_indx >= (len(theme) - 1): + c_indx = -1 + c_indx += 1 + trace_list.append(interval_labels) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for interval in intervals: + fig.append_trace( + trace_list[trace_index][str(interval)], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == 'histogram': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True, + barmode='stack') + return fig + + elif diag == 'box': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + theme = colormap + + # add a copy of rgb color to theme if it contains one color + if len(theme) <= 1: + theme.append(theme[0]) + + color = [] + for incr in range(len(theme)): + color.append([1. / (len(theme) - 1) * incr, theme[incr]]) + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Run through all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=listx, + marker=dict( + color=theme[0]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=listx, + marker=dict( + color=theme[0]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = index_vals + kwargs['marker']['colorscale'] = color + kwargs['marker']['showscale'] = True + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + marker=dict( + size=size, + color=index_vals, + colorscale=color, + showscale=True), + showlegend=False, + **kwargs + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=listx, + marker=dict( + color=theme[0]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=listx, + marker=dict( + color=theme[0]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = index_vals + kwargs['marker']['colorscale'] = color + kwargs['marker']['showscale'] = False + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + marker=dict( + size=size, + color=index_vals, + colorscale=color, + showscale=False), + showlegend=False, + **kwargs + ) + # Push the trace into list + trace_list.append(trace) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + fig.append_trace(trace_list[trace_index], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == 'histogram': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True, + barmode='stack') + return fig + + elif diag == 'box': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + +def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter', + height=500, width=500, size=6, + title='Scatterplot Matrix', colormap=None, + colormap_type='cat', dataframe=None, + headers=None, index_vals=None, **kwargs): + """ + Returns data for a scatterplot matrix. + + :param (array) df: array of the data with column headers + :param (str) index: name of the index column in data array + :param (list|tuple) endpts: takes an increasing sequece of numbers + that defines intervals on the real line. They are used to group + the entries in an index of numbers into their corresponding + interval and therefore can be treated as categorical data + :param (str) diag: sets the chart type for the main diagonal plots. + The options are 'scatter', 'histogram' and 'box'. + :param (int|float) height: sets the height of the chart + :param (int|float) width: sets the width of the chart + :param (float) size: sets the marker size (in px) + :param (str) title: the title label of the scatterplot matrix + :param (str|tuple|list|dict) colormap: either a plotly scale name, + an rgb or hex color, a color tuple, a list of colors or a + dictionary. An rgb color is of the form 'rgb(x, y, z)' where + x, y and z belong to the interval [0, 255] and a color tuple is a + tuple of the form (a, b, c) where a, b and c belong to [0, 1]. + If colormap is a list, it must contain valid color types as its + members. + If colormap is a dictionary, all the string entries in + the index column must be a key in colormap. In this case, the + colormap_type is forced to 'cat' or categorical + :param (str) colormap_type: determines how colormap is interpreted. + Valid choices are 'seq' (sequential) and 'cat' (categorical). If + 'seq' is selected, only the first two colors in colormap will be + considered (when colormap is a list) and the index values will be + linearly interpolated between those two colors. This option is + forced if all index values are numeric. + If 'cat' is selected, a color from colormap will be assigned to + each category from index, including the intervals if endpts is + being used + :param (dict) **kwargs: a dictionary of scatterplot arguments + The only forbidden parameters are 'size', 'color' and + 'colorscale' in 'marker' + + Example 1: Vanilla Scatterplot Matrix + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe + df = pd.DataFrame(np.random.randn(10, 2), + columns=['Column 1', 'Column 2']) + + # Create scatterplot matrix + fig = create_scatterplotmatrix(df) + + # Plot + py.iplot(fig, filename='Vanilla Scatterplot Matrix') + ``` + + Example 2: Indexing a Column + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe with index + df = pd.DataFrame(np.random.randn(10, 2), + columns=['A', 'B']) + + # Add another column of strings to the dataframe + df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', + 'grape', 'pear', 'pear', 'apple', 'pear']) + + # Create scatterplot matrix + fig = create_scatterplotmatrix(df, index='Fruit', size=10) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix with Index') + ``` + + Example 3: Styling the Diagonal Subplots + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe with index + df = pd.DataFrame(np.random.randn(10, 4), + columns=['A', 'B', 'C', 'D']) + + # Add another column of strings to the dataframe + df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', + 'grape', 'pear', 'pear', 'apple', 'pear']) + + # Create scatterplot matrix + fig = create_scatterplotmatrix(df, diag='box', index='Fruit', height=1000, + width=1000) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - Diagonal Styling') + ``` + + Example 4: Use a Theme to Style the Subplots + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe with random data + df = pd.DataFrame(np.random.randn(100, 3), + columns=['A', 'B', 'C']) + + # Create scatterplot matrix using a built-in + # Plotly palette scale and indexing column 'A' + fig = create_scatterplotmatrix(df, diag='histogram', index='A', + colormap='Blues', height=800, width=800) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - Colormap Theme') + ``` + + Example 5: Example 4 with Interval Factoring + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe with random data + df = pd.DataFrame(np.random.randn(100, 3), + columns=['A', 'B', 'C']) + + # Create scatterplot matrix using a list of 2 rgb tuples + # and endpoints at -1, 0 and 1 + fig = create_scatterplotmatrix(df, diag='histogram', index='A', + colormap=['rgb(140, 255, 50)', + 'rgb(170, 60, 115)', '#6c4774', + (0.5, 0.1, 0.8)], + endpts=[-1, 0, 1], height=800, width=800) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - Intervals') + ``` + + Example 6: Using the colormap as a Dictionary + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + import random + + # Create dataframe with random data + df = pd.DataFrame(np.random.randn(100, 3), + columns=['Column A', + 'Column B', + 'Column C']) + + # Add new color column to dataframe + new_column = [] + strange_colors = ['turquoise', 'limegreen', 'goldenrod'] + + for j in range(100): + new_column.append(random.choice(strange_colors)) + df['Colors'] = pd.Series(new_column, index=df.index) + + # Create scatterplot matrix using a dictionary of hex color values + # which correspond to actual color names in 'Colors' column + fig = create_scatterplotmatrix( + df, diag='box', index='Colors', + colormap= dict( + turquoise = '#00F5FF', + limegreen = '#32CD32', + goldenrod = '#DAA520' + ), + colormap_type='cat', + height=800, width=800 + ) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - colormap dictionary ') + ``` + """ + # TODO: protected until #282 + if dataframe is None: + dataframe = [] + if headers is None: + headers = [] + if index_vals is None: + index_vals = [] + + validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs) + + # Validate colormap + if isinstance(colormap, dict): + colormap = utils.validate_colors_dict(colormap, 'rgb') + else: + colormap = utils.validate_colors(colormap, 'rgb') + + if not index: + for name in df: + headers.append(name) + for name in headers: + dataframe.append(df[name].values.tolist()) + # Check for same data-type in df columns + utils.validate_dataframe(dataframe) + figure = scatterplot(dataframe, headers, diag, size, height, width, + title, **kwargs) + return figure + else: + # Validate index selection + if index not in df: + raise exceptions.PlotlyError("Make sure you set the index " + "input variable to one of the " + "column names of your " + "dataframe.") + index_vals = df[index].values.tolist() + for name in df: + if name != index: + headers.append(name) + for name in headers: + dataframe.append(df[name].values.tolist()) + + # check for same data-type in each df column + utils.validate_dataframe(dataframe) + utils.validate_index(index_vals) + + # check if all colormap keys are in the index + # if colormap is a dictionary + if isinstance(colormap, dict): + for key in colormap: + if not all(index in colormap for index in index_vals): + raise exceptions.PlotlyError("If colormap is a " + "dictionary, all the " + "names in the index " + "must be keys.") + figure = scatterplot_dict( + dataframe, headers, diag, size, height, width, title, + index, index_vals, endpts, colormap, colormap_type, + **kwargs + ) + return figure + + else: + figure = scatterplot_theme( + dataframe, headers, diag, size, height, width, title, + index, index_vals, endpts, colormap, colormap_type, + **kwargs + ) + return figure diff --git a/plotly/figure_factory/_streamline.py b/plotly/figure_factory/_streamline.py new file mode 100644 index 00000000000..ddc14778c43 --- /dev/null +++ b/plotly/figure_factory/_streamline.py @@ -0,0 +1,411 @@ +from __future__ import absolute_import + +import math + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + +np = optional_imports.get_module('numpy') + + +def validate_streamline(x, y): + """ + Streamline-specific validations + + Specifically, this checks that x and y are both evenly spaced, + and that the package numpy is available. + + See FigureFactory.create_streamline() for params + + :raises: (ImportError) If numpy is not available. + :raises: (PlotlyError) If x is not evenly spaced. + :raises: (PlotlyError) If y is not evenly spaced. + """ + if np is False: + raise ImportError("FigureFactory.create_streamline requires numpy") + for index in range(len(x) - 1): + if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001: + raise exceptions.PlotlyError("x must be a 1 dimensional, " + "evenly spaced array") + for index in range(len(y) - 1): + if ((y[index + 1] - y[index]) - + (y[1] - y[0])) > .0001: + raise exceptions.PlotlyError("y must be a 1 dimensional, " + "evenly spaced array") + + +def create_streamline(x, y, u, v, density=1, angle=math.pi / 9, + arrow_scale=.09, **kwargs): + """ + Returns data for a streamline plot. + + :param (list|ndarray) x: 1 dimensional, evenly spaced list or array + :param (list|ndarray) y: 1 dimensional, evenly spaced list or array + :param (ndarray) u: 2 dimensional array + :param (ndarray) v: 2 dimensional array + :param (float|int) density: controls the density of streamlines in + plot. This is multiplied by 30 to scale similiarly to other + available streamline functions such as matplotlib. + Default = 1 + :param (angle in radians) angle: angle of arrowhead. Default = pi/9 + :param (float in [0,1]) arrow_scale: value to scale length of arrowhead + Default = .09 + :param kwargs: kwargs passed through plotly.graph_objs.Scatter + for more information on valid kwargs call + help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of streamline figure. + + Example 1: Plot simple streamline and increase arrow size + ``` + import plotly.plotly as py + from plotly.figure_factory import create_streamline + + import numpy as np + import math + + # Add data + x = np.linspace(-3, 3, 100) + y = np.linspace(-3, 3, 100) + Y, X = np.meshgrid(x, y) + u = -1 - X**2 + Y + v = 1 + X - Y**2 + u = u.T # Transpose + v = v.T # Transpose + + # Create streamline + fig = create_streamline(x, y, u, v, arrow_scale=.1) + + # Plot + py.plot(fig, filename='streamline') + ``` + + Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython + ``` + import plotly.plotly as py + from plotly.figure_factory import create_streamline + + import numpy as np + import math + + # Add data + N = 50 + x_start, x_end = -2.0, 2.0 + y_start, y_end = -1.0, 1.0 + x = np.linspace(x_start, x_end, N) + y = np.linspace(y_start, y_end, N) + X, Y = np.meshgrid(x, y) + ss = 5.0 + x_s, y_s = -1.0, 0.0 + + # Compute the velocity field on the mesh grid + u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2) + v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2) + + # Create streamline + fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline') + + # Add source point + point = Scatter(x=[x_s], y=[y_s], mode='markers', + marker=Marker(size=14), name='source point') + + # Plot + fig['data'].append(point) + py.plot(fig, filename='streamline') + ``` + """ + utils.validate_equal_length(x, y) + utils.validate_equal_length(u, v) + validate_streamline(x, y) + utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale) + + streamline_x, streamline_y = _Streamline(x, y, u, v, + density, angle, + arrow_scale).sum_streamlines() + arrow_x, arrow_y = _Streamline(x, y, u, v, + density, angle, + arrow_scale).get_streamline_arrows() + + streamline = graph_objs.Scatter(x=streamline_x + arrow_x, + y=streamline_y + arrow_y, + mode='lines', **kwargs) + + data = [streamline] + layout = graph_objs.Layout(hovermode='closest') + + return graph_objs.Figure(data=data, layout=layout) + + +class _Streamline(object): + """ + Refer to FigureFactory.create_streamline() for docstring + """ + def __init__(self, x, y, u, v, + density, angle, + arrow_scale, **kwargs): + self.x = np.array(x) + self.y = np.array(y) + self.u = np.array(u) + self.v = np.array(v) + self.angle = angle + self.arrow_scale = arrow_scale + self.density = int(30 * density) # Scale similarly to other functions + self.delta_x = self.x[1] - self.x[0] + self.delta_y = self.y[1] - self.y[0] + self.val_x = self.x + self.val_y = self.y + + # Set up spacing + self.blank = np.zeros((self.density, self.density)) + self.spacing_x = len(self.x) / float(self.density - 1) + self.spacing_y = len(self.y) / float(self.density - 1) + self.trajectories = [] + + # Rescale speed onto axes-coordinates + self.u = self.u / (self.x[-1] - self.x[0]) + self.v = self.v / (self.y[-1] - self.y[0]) + self.speed = np.sqrt(self.u ** 2 + self.v ** 2) + + # Rescale u and v for integrations. + self.u *= len(self.x) + self.v *= len(self.y) + self.st_x = [] + self.st_y = [] + self.get_streamlines() + streamline_x, streamline_y = self.sum_streamlines() + arrows_x, arrows_y = self.get_streamline_arrows() + + def blank_pos(self, xi, yi): + """ + Set up positions for trajectories to be used with rk4 function. + """ + return (int((xi / self.spacing_x) + 0.5), + int((yi / self.spacing_y) + 0.5)) + + def value_at(self, a, xi, yi): + """ + Set up for RK4 function, based on Bokeh's streamline code + """ + if isinstance(xi, np.ndarray): + self.x = xi.astype(np.int) + self.y = yi.astype(np.int) + else: + self.val_x = np.int(xi) + self.val_y = np.int(yi) + a00 = a[self.val_y, self.val_x] + a01 = a[self.val_y, self.val_x + 1] + a10 = a[self.val_y + 1, self.val_x] + a11 = a[self.val_y + 1, self.val_x + 1] + xt = xi - self.val_x + yt = yi - self.val_y + a0 = a00 * (1 - xt) + a01 * xt + a1 = a10 * (1 - xt) + a11 * xt + return a0 * (1 - yt) + a1 * yt + + def rk4_integrate(self, x0, y0): + """ + RK4 forward and back trajectories from the initial conditions. + + Adapted from Bokeh's streamline -uses Runge-Kutta method to fill + x and y trajectories then checks length of traj (s in units of axes) + """ + def f(xi, yi): + dt_ds = 1. / self.value_at(self.speed, xi, yi) + ui = self.value_at(self.u, xi, yi) + vi = self.value_at(self.v, xi, yi) + return ui * dt_ds, vi * dt_ds + + def g(xi, yi): + dt_ds = 1. / self.value_at(self.speed, xi, yi) + ui = self.value_at(self.u, xi, yi) + vi = self.value_at(self.v, xi, yi) + return -ui * dt_ds, -vi * dt_ds + + check = lambda xi, yi: (0 <= xi < len(self.x) - 1 and + 0 <= yi < len(self.y) - 1) + xb_changes = [] + yb_changes = [] + + def rk4(x0, y0, f): + ds = 0.01 + stotal = 0 + xi = x0 + yi = y0 + xb, yb = self.blank_pos(xi, yi) + xf_traj = [] + yf_traj = [] + while check(xi, yi): + xf_traj.append(xi) + yf_traj.append(yi) + try: + k1x, k1y = f(xi, yi) + k2x, k2y = f(xi + .5 * ds * k1x, yi + .5 * ds * k1y) + k3x, k3y = f(xi + .5 * ds * k2x, yi + .5 * ds * k2y) + k4x, k4y = f(xi + ds * k3x, yi + ds * k3y) + except IndexError: + break + xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6. + yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6. + if not check(xi, yi): + break + stotal += ds + new_xb, new_yb = self.blank_pos(xi, yi) + if new_xb != xb or new_yb != yb: + if self.blank[new_yb, new_xb] == 0: + self.blank[new_yb, new_xb] = 1 + xb_changes.append(new_xb) + yb_changes.append(new_yb) + xb = new_xb + yb = new_yb + else: + break + if stotal > 2: + break + return stotal, xf_traj, yf_traj + + sf, xf_traj, yf_traj = rk4(x0, y0, f) + sb, xb_traj, yb_traj = rk4(x0, y0, g) + stotal = sf + sb + x_traj = xb_traj[::-1] + xf_traj[1:] + y_traj = yb_traj[::-1] + yf_traj[1:] + + if len(x_traj) < 1: + return None + if stotal > .2: + initxb, inityb = self.blank_pos(x0, y0) + self.blank[inityb, initxb] = 1 + return x_traj, y_traj + else: + for xb, yb in zip(xb_changes, yb_changes): + self.blank[yb, xb] = 0 + return None + + def traj(self, xb, yb): + """ + Integrate trajectories + + :param (int) xb: results of passing xi through self.blank_pos + :param (int) xy: results of passing yi through self.blank_pos + + Calculate each trajectory based on rk4 integrate method. + """ + + if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density: + return + if self.blank[yb, xb] == 0: + t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y) + if t is not None: + self.trajectories.append(t) + + def get_streamlines(self): + """ + Get streamlines by building trajectory set. + """ + for indent in range(self.density // 2): + for xi in range(self.density - 2 * indent): + self.traj(xi + indent, indent) + self.traj(xi + indent, self.density - 1 - indent) + self.traj(indent, xi + indent) + self.traj(self.density - 1 - indent, xi + indent) + + self.st_x = [np.array(t[0]) * self.delta_x + self.x[0] for t in + self.trajectories] + self.st_y = [np.array(t[1]) * self.delta_y + self.y[0] for t in + self.trajectories] + + for index in range(len(self.st_x)): + self.st_x[index] = self.st_x[index].tolist() + self.st_x[index].append(np.nan) + + for index in range(len(self.st_y)): + self.st_y[index] = self.st_y[index].tolist() + self.st_y[index].append(np.nan) + + def get_streamline_arrows(self): + """ + Makes an arrow for each streamline. + + Gets angle of streamline at 1/3 mark and creates arrow coordinates + based off of user defined angle and arrow_scale. + + :param (array) st_x: x-values for all streamlines + :param (array) st_y: y-values for all streamlines + :param (angle in radians) angle: angle of arrowhead. Default = pi/9 + :param (float in [0,1]) arrow_scale: value to scale length of arrowhead + Default = .09 + :rtype (list, list) arrows_x: x-values to create arrowhead and + arrows_y: y-values to create arrowhead + """ + arrow_end_x = np.empty((len(self.st_x))) + arrow_end_y = np.empty((len(self.st_y))) + arrow_start_x = np.empty((len(self.st_x))) + arrow_start_y = np.empty((len(self.st_y))) + for index in range(len(self.st_x)): + arrow_end_x[index] = (self.st_x[index] + [int(len(self.st_x[index]) / 3)]) + arrow_start_x[index] = (self.st_x[index] + [(int(len(self.st_x[index]) / 3)) - 1]) + arrow_end_y[index] = (self.st_y[index] + [int(len(self.st_y[index]) / 3)]) + arrow_start_y[index] = (self.st_y[index] + [(int(len(self.st_y[index]) / 3)) - 1]) + + dif_x = arrow_end_x - arrow_start_x + dif_y = arrow_end_y - arrow_start_y + + streamline_ang = np.arctan(dif_y / dif_x) + + ang1 = streamline_ang + (self.angle) + ang2 = streamline_ang - (self.angle) + + seg1_x = np.cos(ang1) * self.arrow_scale + seg1_y = np.sin(ang1) * self.arrow_scale + seg2_x = np.cos(ang2) * self.arrow_scale + seg2_y = np.sin(ang2) * self.arrow_scale + + point1_x = np.empty((len(dif_x))) + point1_y = np.empty((len(dif_y))) + point2_x = np.empty((len(dif_x))) + point2_y = np.empty((len(dif_y))) + + for index in range(len(dif_x)): + if dif_x[index] >= 0: + point1_x[index] = arrow_end_x[index] - seg1_x[index] + point1_y[index] = arrow_end_y[index] - seg1_y[index] + point2_x[index] = arrow_end_x[index] - seg2_x[index] + point2_y[index] = arrow_end_y[index] - seg2_y[index] + else: + point1_x[index] = arrow_end_x[index] + seg1_x[index] + point1_y[index] = arrow_end_y[index] + seg1_y[index] + point2_x[index] = arrow_end_x[index] + seg2_x[index] + point2_y[index] = arrow_end_y[index] + seg2_y[index] + + space = np.empty((len(point1_x))) + space[:] = np.nan + + # Combine arrays into matrix + arrows_x = np.matrix([point1_x, arrow_end_x, point2_x, space]) + arrows_x = np.array(arrows_x) + arrows_x = arrows_x.flatten('F') + arrows_x = arrows_x.tolist() + + # Combine arrays into matrix + arrows_y = np.matrix([point1_y, arrow_end_y, point2_y, space]) + arrows_y = np.array(arrows_y) + arrows_y = arrows_y.flatten('F') + arrows_y = arrows_y.tolist() + + return arrows_x, arrows_y + + def sum_streamlines(self): + """ + Makes all streamlines readable as a single trace. + + :rtype (list, list): streamline_x: all x values for each streamline + combined into single list and streamline_y: all y values for each + streamline combined into single list + """ + streamline_x = sum(self.st_x, []) + streamline_y = sum(self.st_y, []) + return streamline_x, streamline_y diff --git a/plotly/figure_factory/_table.py b/plotly/figure_factory/_table.py new file mode 100644 index 00000000000..001dc9ea0ec --- /dev/null +++ b/plotly/figure_factory/_table.py @@ -0,0 +1,232 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.graph_objs import graph_objs + +pd = optional_imports.get_module('pandas') + + +def validate_table(table_text, font_colors): + """ + Table-specific validations + + Check that font_colors is supplied correctly (1, 3, or len(text) + colors). + + :raises: (PlotlyError) If font_colors is supplied incorretly. + + See FigureFactory.create_table() for params + """ + font_colors_len_options = [1, 3, len(table_text)] + if len(font_colors) not in font_colors_len_options: + raise exceptions.PlotlyError("Oops, font_colors should be a list " + "of length 1, 3 or len(text)") + + +def create_table(table_text, colorscale=None, font_colors=None, + index=False, index_title='', annotation_offset=.45, + height_constant=30, hoverinfo='none', **kwargs): + """ + BETA function that creates data tables + + :param (pandas.Dataframe | list[list]) text: data for table. + :param (str|list[list]) colorscale: Colorscale for table where the + color at value 0 is the header color, .5 is the first table color + and 1 is the second table color. (Set .5 and 1 to avoid the striped + table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'], + [1, '#ffffff']] + :param (list) font_colors: Color for fonts in table. Can be a single + color, three colors, or a color for each row in the table. + Default=['#000000'] (black text for the entire table) + :param (int) height_constant: Constant multiplied by # of rows to + create table height. Default=30. + :param (bool) index: Create (header-colored) index column index from + Pandas dataframe or list[0] for each list in text. Default=False. + :param (string) index_title: Title for index column. Default=''. + :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. + These kwargs describe other attributes about the annotated Heatmap + trace such as the colorscale. For more information on valid kwargs + call help(plotly.graph_objs.Heatmap) + + Example 1: Simple Plotly Table + ``` + import plotly.plotly as py + from plotly.figure_factory import create_table + + text = [['Country', 'Year', 'Population'], + ['US', 2000, 282200000], + ['Canada', 2000, 27790000], + ['US', 2010, 309000000], + ['Canada', 2010, 34000000]] + + table = create_table(text) + py.iplot(table) + ``` + + Example 2: Table with Custom Coloring + ``` + import plotly.plotly as py + from plotly.figure_factory import create_table + + text = [['Country', 'Year', 'Population'], + ['US', 2000, 282200000], + ['Canada', 2000, 27790000], + ['US', 2010, 309000000], + ['Canada', 2010, 34000000]] + + table = create_table(text, + colorscale=[[0, '#000000'], + [.5, '#80beff'], + [1, '#cce5ff']], + font_colors=['#ffffff', '#000000', + '#000000']) + py.iplot(table) + ``` + Example 3: Simple Plotly Table with Pandas + ``` + import plotly.plotly as py + from plotly.figure_factory import create_table + + import pandas as pd + + df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t') + df_p = df[0:25] + + table_simple = create_table(df_p) + py.iplot(table_simple) + ``` + """ + + # Avoiding mutables in the call signature + colorscale = \ + colorscale if colorscale is not None else [[0, '#00083e'], + [.5, '#ededee'], + [1, '#ffffff']] + font_colors = font_colors if font_colors is not None else ['#ffffff', + '#000000', + '#000000'] + + validate_table(table_text, font_colors) + table_matrix = _Table(table_text, colorscale, font_colors, index, + index_title, annotation_offset, + **kwargs).get_table_matrix() + annotations = _Table(table_text, colorscale, font_colors, index, + index_title, annotation_offset, + **kwargs).make_table_annotations() + + trace = dict(type='heatmap', z=table_matrix, opacity=.75, + colorscale=colorscale, showscale=False, + hoverinfo=hoverinfo, **kwargs) + + data = [trace] + layout = dict(annotations=annotations, + height=len(table_matrix) * height_constant + 50, + margin=dict(t=0, b=0, r=0, l=0), + yaxis=dict(autorange='reversed', zeroline=False, + gridwidth=2, ticks='', dtick=1, tick0=.5, + showticklabels=False), + xaxis=dict(zeroline=False, gridwidth=2, ticks='', + dtick=1, tick0=-0.5, showticklabels=False)) + return graph_objs.Figure(data=data, layout=layout) + + +class _Table(object): + """ + Refer to TraceFactory.create_table() for docstring + """ + def __init__(self, table_text, colorscale, font_colors, index, + index_title, annotation_offset, **kwargs): + if pd and isinstance(table_text, pd.DataFrame): + headers = table_text.columns.tolist() + table_text_index = table_text.index.tolist() + table_text = table_text.values.tolist() + table_text.insert(0, headers) + if index: + table_text_index.insert(0, index_title) + for i in range(len(table_text)): + table_text[i].insert(0, table_text_index[i]) + self.table_text = table_text + self.colorscale = colorscale + self.font_colors = font_colors + self.index = index + self.annotation_offset = annotation_offset + self.x = range(len(table_text[0])) + self.y = range(len(table_text)) + + def get_table_matrix(self): + """ + Create z matrix to make heatmap with striped table coloring + + :rtype (list[list]) table_matrix: z matrix to make heatmap with striped + table coloring. + """ + header = [0] * len(self.table_text[0]) + odd_row = [.5] * len(self.table_text[0]) + even_row = [1] * len(self.table_text[0]) + table_matrix = [None] * len(self.table_text) + table_matrix[0] = header + for i in range(1, len(self.table_text), 2): + table_matrix[i] = odd_row + for i in range(2, len(self.table_text), 2): + table_matrix[i] = even_row + if self.index: + for array in table_matrix: + array[0] = 0 + return table_matrix + + def get_table_font_color(self): + """ + Fill font-color array. + + Table text color can vary by row so this extends a single color or + creates an array to set a header color and two alternating colors to + create the striped table pattern. + + :rtype (list[list]) all_font_colors: list of font colors for each row + in table. + """ + if len(self.font_colors) == 1: + all_font_colors = self.font_colors*len(self.table_text) + elif len(self.font_colors) == 3: + all_font_colors = list(range(len(self.table_text))) + all_font_colors[0] = self.font_colors[0] + for i in range(1, len(self.table_text), 2): + all_font_colors[i] = self.font_colors[1] + for i in range(2, len(self.table_text), 2): + all_font_colors[i] = self.font_colors[2] + elif len(self.font_colors) == len(self.table_text): + all_font_colors = self.font_colors + else: + all_font_colors = ['#000000']*len(self.table_text) + return all_font_colors + + def make_table_annotations(self): + """ + Generate annotations to fill in table text + + :rtype (list) annotations: list of annotations for each cell of the + table. + """ + table_matrix = _Table.get_table_matrix(self) + all_font_colors = _Table.get_table_font_color(self) + annotations = [] + for n, row in enumerate(self.table_text): + for m, val in enumerate(row): + # Bold text in header and index + format_text = ('' + str(val) + '' if n == 0 or + self.index and m < 1 else str(val)) + # Match font color of index to font color of header + font_color = (self.font_colors[0] if self.index and m == 0 + else all_font_colors[n]) + annotations.append( + graph_objs.Annotation( + text=format_text, + x=self.x[m] - self.annotation_offset, + y=self.y[n], + xref='x1', + yref='y1', + align="left", + xanchor="left", + font=dict(color=font_color), + showarrow=False)) + return annotations diff --git a/plotly/figure_factory/_trisurf.py b/plotly/figure_factory/_trisurf.py new file mode 100644 index 00000000000..d2b58420471 --- /dev/null +++ b/plotly/figure_factory/_trisurf.py @@ -0,0 +1,488 @@ +from __future__ import absolute_import + +from plotly import colors, exceptions, optional_imports +from plotly.graph_objs import graph_objs + +np = optional_imports.get_module('numpy') + + +def map_face2color(face, colormap, scale, vmin, vmax): + """ + Normalize facecolor values by vmin/vmax and return rgb-color strings + + This function takes a tuple color along with a colormap and a minimum + (vmin) and maximum (vmax) range of possible mean distances for the + given parametrized surface. It returns an rgb color based on the mean + distance between vmin and vmax + + """ + if vmin >= vmax: + raise exceptions.PlotlyError("Incorrect relation between vmin " + "and vmax. The vmin value cannot be " + "bigger than or equal to the value " + "of vmax.") + if len(colormap) == 1: + # color each triangle face with the same color in colormap + face_color = colormap[0] + face_color = colors.convert_to_RGB_255(face_color) + face_color = colors.label_rgb(face_color) + return face_color + if face == vmax: + # pick last color in colormap + face_color = colormap[-1] + face_color = colors.convert_to_RGB_255(face_color) + face_color = colors.label_rgb(face_color) + return face_color + else: + if scale is None: + # find the normalized distance t of a triangle face between + # vmin and vmax where the distance is between 0 and 1 + t = (face - vmin) / float((vmax - vmin)) + low_color_index = int(t / (1./(len(colormap) - 1))) + + face_color = colors.find_intermediate_color( + colormap[low_color_index], + colormap[low_color_index + 1], + t * (len(colormap) - 1) - low_color_index + ) + + face_color = colors.convert_to_RGB_255(face_color) + face_color = colors.label_rgb(face_color) + else: + # find the face color for a non-linearly interpolated scale + t = (face - vmin) / float((vmax - vmin)) + + low_color_index = 0 + for k in range(len(scale) - 1): + if scale[k] <= t < scale[k+1]: + break + low_color_index += 1 + + low_scale_val = scale[low_color_index] + high_scale_val = scale[low_color_index + 1] + + face_color = colors.find_intermediate_color( + colormap[low_color_index], + colormap[low_color_index + 1], + (t - low_scale_val)/(high_scale_val - low_scale_val) + ) + + face_color = colors.convert_to_RGB_255(face_color) + face_color = colors.label_rgb(face_color) + return face_color + + +def trisurf(x, y, z, simplices, show_colorbar, edges_color, scale, + colormap=None, color_func=None, plot_edges=False, x_edge=None, + y_edge=None, z_edge=None, facecolor=None): + """ + Refer to FigureFactory.create_trisurf() for docstring + """ + # numpy import check + if not np: + raise ImportError("FigureFactory._trisurf() requires " + "numpy imported.") + points3D = np.vstack((x, y, z)).T + simplices = np.atleast_2d(simplices) + + # vertices of the surface triangles + tri_vertices = points3D[simplices] + + # Define colors for the triangle faces + if color_func is None: + # mean values of z-coordinates of triangle vertices + mean_dists = tri_vertices[:, :, 2].mean(-1) + elif isinstance(color_func, (list, np.ndarray)): + # Pre-computed list / array of values to map onto color + if len(color_func) != len(simplices): + raise ValueError("If color_func is a list/array, it must " + "be the same length as simplices.") + + # convert all colors in color_func to rgb + for index in range(len(color_func)): + if isinstance(color_func[index], str): + if '#' in color_func[index]: + foo = colors.hex_to_rgb(color_func[index]) + color_func[index] = colors.label_rgb(foo) + + if isinstance(color_func[index], tuple): + foo = colors.convert_to_RGB_255(color_func[index]) + color_func[index] = colors.label_rgb(foo) + + mean_dists = np.asarray(color_func) + else: + # apply user inputted function to calculate + # custom coloring for triangle vertices + mean_dists = [] + for triangle in tri_vertices: + dists = [] + for vertex in triangle: + dist = color_func(vertex[0], vertex[1], vertex[2]) + dists.append(dist) + mean_dists.append(np.mean(dists)) + mean_dists = np.asarray(mean_dists) + + # Check if facecolors are already strings and can be skipped + if isinstance(mean_dists[0], str): + facecolor = mean_dists + else: + min_mean_dists = np.min(mean_dists) + max_mean_dists = np.max(mean_dists) + + if facecolor is None: + facecolor = [] + for index in range(len(mean_dists)): + color = map_face2color(mean_dists[index], colormap, scale, + min_mean_dists, max_mean_dists) + facecolor.append(color) + + # Make sure facecolor is a list so output is consistent across Pythons + facecolor = np.asarray(facecolor) + ii, jj, kk = simplices.T + + triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor, + i=ii, j=jj, k=kk, name='') + + mean_dists_are_numbers = not isinstance(mean_dists[0], str) + + if mean_dists_are_numbers and show_colorbar is True: + # make a colorscale from the colors + colorscale = colors.make_colorscale(colormap, scale) + colorscale = colors.convert_colorscale_to_rgb(colorscale) + + colorbar = graph_objs.Scatter3d( + x=x[:1], + y=y[:1], + z=z[:1], + mode='markers', + marker=dict( + size=0.1, + color=[min_mean_dists, max_mean_dists], + colorscale=colorscale, + showscale=True), + hoverinfo='None', + showlegend=False + ) + + # the triangle sides are not plotted + if plot_edges is False: + if mean_dists_are_numbers and show_colorbar is True: + return graph_objs.Data([triangles, colorbar]) + else: + return graph_objs.Data([triangles]) + + # define the lists x_edge, y_edge and z_edge, of x, y, resp z + # coordinates of edge end points for each triangle + # None separates data corresponding to two consecutive triangles + is_none = [ii is None for ii in [x_edge, y_edge, z_edge]] + if any(is_none): + if not all(is_none): + raise ValueError("If any (x_edge, y_edge, z_edge) is None, " + "all must be None") + else: + x_edge = [] + y_edge = [] + z_edge = [] + + # Pull indices we care about, then add a None column to separate tris + ixs_triangles = [0, 1, 2, 0] + pull_edges = tri_vertices[:, ixs_triangles, :] + x_edge_pull = np.hstack([pull_edges[:, :, 0], + np.tile(None, [pull_edges.shape[0], 1])]) + y_edge_pull = np.hstack([pull_edges[:, :, 1], + np.tile(None, [pull_edges.shape[0], 1])]) + z_edge_pull = np.hstack([pull_edges[:, :, 2], + np.tile(None, [pull_edges.shape[0], 1])]) + + # Now unravel the edges into a 1-d vector for plotting + x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]]) + y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]]) + z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]]) + + if not (len(x_edge) == len(y_edge) == len(z_edge)): + raise exceptions.PlotlyError("The lengths of x_edge, y_edge and " + "z_edge are not the same.") + + # define the lines for plotting + lines = graph_objs.Scatter3d( + x=x_edge, y=y_edge, z=z_edge, mode='lines', + line=graph_objs.Line( + color=edges_color, + width=1.5 + ), + showlegend=False + ) + + if mean_dists_are_numbers and show_colorbar is True: + return graph_objs.Data([triangles, lines, colorbar]) + else: + return graph_objs.Data([triangles, lines]) + + +def create_trisurf(x, y, z, simplices, colormap=None, show_colorbar=True, + scale=None, color_func=None, title='Trisurf Plot', + plot_edges=True, showbackground=True, + backgroundcolor='rgb(230, 230, 230)', + gridcolor='rgb(255, 255, 255)', + zerolinecolor='rgb(255, 255, 255)', + edges_color='rgb(50, 50, 50)', + height=800, width=800, + aspectratio=None): + """ + Returns figure for a triangulated surface plot + + :param (array) x: data values of x in a 1D array + :param (array) y: data values of y in a 1D array + :param (array) z: data values of z in a 1D array + :param (array) simplices: an array of shape (ntri, 3) where ntri is + the number of triangles in the triangularization. Each row of the + array contains the indicies of the verticies of each triangle + :param (str|tuple|list) colormap: either a plotly scale name, an rgb + or hex color, a color tuple or a list of colors. An rgb color is + of the form 'rgb(x, y, z)' where x, y, z belong to the interval + [0, 255] and a color tuple is a tuple of the form (a, b, c) where + a, b and c belong to [0, 1]. If colormap is a list, it must + contain the valid color types aforementioned as its members + :param (bool) show_colorbar: determines if colorbar is visible + :param (list|array) scale: sets the scale values to be used if a non- + linearly interpolated colormap is desired. If left as None, a + linear interpolation between the colors will be excecuted + :param (function|list) color_func: The parameter that determines the + coloring of the surface. Takes either a function with 3 arguments + x, y, z or a list/array of color values the same length as + simplices. If None, coloring will only depend on the z axis + :param (str) title: title of the plot + :param (bool) plot_edges: determines if the triangles on the trisurf + are visible + :param (bool) showbackground: makes background in plot visible + :param (str) backgroundcolor: color of background. Takes a string of + the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive + :param (str) gridcolor: color of the gridlines besides the axes. Takes + a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255 + inclusive + :param (str) zerolinecolor: color of the axes. Takes a string of the + form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive + :param (str) edges_color: color of the edges, if plot_edges is True + :param (int|float) height: the height of the plot (in pixels) + :param (int|float) width: the width of the plot (in pixels) + :param (dict) aspectratio: a dictionary of the aspect ratio values for + the x, y and z axes. 'x', 'y' and 'z' take (int|float) values + + Example 1: Sphere + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 20) + v = np.linspace(0, np.pi, 20) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + x = np.sin(v)*np.cos(u) + y = np.sin(v)*np.sin(u) + z = np.cos(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow", + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='trisurf-plot-sphere') + ``` + + Example 2: Torus + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 20) + v = np.linspace(0, 2*np.pi, 20) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + x = (3 + (np.cos(v)))*np.cos(u) + y = (3 + (np.cos(v)))*np.sin(u) + z = np.sin(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis", + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='trisurf-plot-torus') + ``` + + Example 3: Mobius Band + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 24) + v = np.linspace(-1, 1, 8) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + tp = 1 + 0.5*v*np.cos(u/2.) + x = tp*np.cos(u) + y = tp*np.sin(u) + z = 0.5*v*np.sin(u/2.) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)], + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='trisurf-plot-mobius-band') + ``` + + Example 4: Using a Custom Colormap Function with Light Cone + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u=np.linspace(-np.pi, np.pi, 30) + v=np.linspace(-np.pi, np.pi, 30) + u,v=np.meshgrid(u,v) + u=u.flatten() + v=v.flatten() + + x = u + y = u*np.cos(v) + z = u*np.sin(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Define distance function + def dist_origin(x, y, z): + return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2) + + # Create a figure + fig1 = create_trisurf(x=x, y=y, z=z, + colormap=['#FFFFFF', '#E4FFFE', + '#A4F6F9', '#FF99FE', + '#BA52ED'], + scale=[0, 0.6, 0.71, 0.89, 1], + simplices=simplices, + color_func=dist_origin) + # Plot the data + py.iplot(fig1, filename='trisurf-plot-custom-coloring') + ``` + + Example 5: Enter color_func as a list of colors + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + import random + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u=np.linspace(-np.pi, np.pi, 30) + v=np.linspace(-np.pi, np.pi, 30) + u,v=np.meshgrid(u,v) + u=u.flatten() + v=v.flatten() + + x = u + y = u*np.cos(v) + z = u*np.sin(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + + colors = [] + color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd'] + + for index in range(len(simplices)): + colors.append(random.choice(color_choices)) + + fig = create_trisurf( + x, y, z, simplices, + color_func=colors, + show_colorbar=True, + edges_color='rgb(2, 85, 180)', + title=' Modern Art' + ) + + py.iplot(fig, filename="trisurf-plot-modern-art") + ``` + """ + if aspectratio is None: + aspectratio = {'x': 1, 'y': 1, 'z': 1} + + # Validate colormap + colors.validate_colors(colormap) + colormap, scale = colors.convert_colors_to_same_type( + colormap, colortype='tuple', + return_default_colors=True, scale=scale + ) + + data1 = trisurf(x, y, z, simplices, show_colorbar=show_colorbar, + color_func=color_func, colormap=colormap, scale=scale, + edges_color=edges_color, plot_edges=plot_edges) + + axis = dict( + showbackground=showbackground, + backgroundcolor=backgroundcolor, + gridcolor=gridcolor, + zerolinecolor=zerolinecolor, + ) + layout = graph_objs.Layout( + title=title, + width=width, + height=height, + scene=graph_objs.Scene( + xaxis=graph_objs.XAxis(axis), + yaxis=graph_objs.YAxis(axis), + zaxis=graph_objs.ZAxis(axis), + aspectratio=dict( + x=aspectratio['x'], + y=aspectratio['y'], + z=aspectratio['z']), + ) + ) + + return graph_objs.Figure(data=data1, layout=layout) diff --git a/plotly/figure_factory/_violin.py b/plotly/figure_factory/_violin.py new file mode 100644 index 00000000000..d501105482e --- /dev/null +++ b/plotly/figure_factory/_violin.py @@ -0,0 +1,627 @@ +from __future__ import absolute_import + +from numbers import Number + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs +from plotly.tools import make_subplots + +pd = optional_imports.get_module('pandas') +np = optional_imports.get_module('numpy') +scipy_stats = optional_imports.get_module('scipy.stats') + + +def calc_stats(data): + """ + Calculate statistics for use in violin plot. + """ + x = np.asarray(data, np.float) + vals_min = np.min(x) + vals_max = np.max(x) + q2 = np.percentile(x, 50, interpolation='linear') + q1 = np.percentile(x, 25, interpolation='lower') + q3 = np.percentile(x, 75, interpolation='higher') + iqr = q3 - q1 + whisker_dist = 1.5 * iqr + + # in order to prevent drawing whiskers outside the interval + # of data one defines the whisker positions as: + d1 = np.min(x[x >= (q1 - whisker_dist)]) + d2 = np.max(x[x <= (q3 + whisker_dist)]) + return { + 'min': vals_min, + 'max': vals_max, + 'q1': q1, + 'q2': q2, + 'q3': q3, + 'd1': d1, + 'd2': d2 + } + + +def make_half_violin(x, y, fillcolor='#1f77b4', linecolor='rgb(0, 0, 0)'): + """ + Produces a sideways probability distribution fig violin plot. + """ + text = ['(pdf(y), y)=(' + '{:0.2f}'.format(x[i]) + + ', ' + '{:0.2f}'.format(y[i]) + ')' + for i in range(len(x))] + + return graph_objs.Scatter( + x=x, + y=y, + mode='lines', + name='', + text=text, + fill='tonextx', + fillcolor=fillcolor, + line=graph_objs.Line(width=0.5, color=linecolor, shape='spline'), + hoverinfo='text', + opacity=0.5 + ) + + +def make_violin_rugplot(vals, pdf_max, distance, color='#1f77b4'): + """ + Returns a rugplot fig for a violin plot. + """ + return graph_objs.Scatter( + y=vals, + x=[-pdf_max-distance]*len(vals), + marker=graph_objs.Marker( + color=color, + symbol='line-ew-open' + ), + mode='markers', + name='', + showlegend=False, + hoverinfo='y' + ) + + +def make_non_outlier_interval(d1, d2): + """ + Returns the scatterplot fig of most of a violin plot. + """ + return graph_objs.Scatter( + x=[0, 0], + y=[d1, d2], + name='', + mode='lines', + line=graph_objs.Line(width=1.5, + color='rgb(0,0,0)') + ) + + +def make_quartiles(q1, q3): + """ + Makes the upper and lower quartiles for a violin plot. + """ + return graph_objs.Scatter( + x=[0, 0], + y=[q1, q3], + text=['lower-quartile: ' + '{:0.2f}'.format(q1), + 'upper-quartile: ' + '{:0.2f}'.format(q3)], + mode='lines', + line=graph_objs.Line( + width=4, + color='rgb(0,0,0)' + ), + hoverinfo='text' + ) + + +def make_median(q2): + """ + Formats the 'median' hovertext for a violin plot. + """ + return graph_objs.Scatter( + x=[0], + y=[q2], + text=['median: ' + '{:0.2f}'.format(q2)], + mode='markers', + marker=dict(symbol='square', + color='rgb(255,255,255)'), + hoverinfo='text' + ) + + +def make_XAxis(xaxis_title, xaxis_range): + """ + Makes the x-axis for a violin plot. + """ + xaxis = graph_objs.XAxis(title=xaxis_title, + range=xaxis_range, + showgrid=False, + zeroline=False, + showline=False, + mirror=False, + ticks='', + showticklabels=False) + return xaxis + + +def make_YAxis(yaxis_title): + """ + Makes the y-axis for a violin plot. + """ + yaxis = graph_objs.YAxis(title=yaxis_title, + showticklabels=True, + autorange=True, + ticklen=4, + showline=True, + zeroline=False, + showgrid=False, + mirror=False) + return yaxis + + +def violinplot(vals, fillcolor='#1f77b4', rugplot=True): + """ + Refer to FigureFactory.create_violin() for docstring. + """ + vals = np.asarray(vals, np.float) + # summary statistics + vals_min = calc_stats(vals)['min'] + vals_max = calc_stats(vals)['max'] + q1 = calc_stats(vals)['q1'] + q2 = calc_stats(vals)['q2'] + q3 = calc_stats(vals)['q3'] + d1 = calc_stats(vals)['d1'] + d2 = calc_stats(vals)['d2'] + + # kernel density estimation of pdf + pdf = scipy_stats.gaussian_kde(vals) + # grid over the data interval + xx = np.linspace(vals_min, vals_max, 100) + # evaluate the pdf at the grid xx + yy = pdf(xx) + max_pdf = np.max(yy) + # distance from the violin plot to rugplot + distance = (2.0 * max_pdf)/10 if rugplot else 0 + # range for x values in the plot + plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1] + plot_data = [make_half_violin(-yy, xx, fillcolor=fillcolor), + make_half_violin(yy, xx, fillcolor=fillcolor), + make_non_outlier_interval(d1, d2), + make_quartiles(q1, q3), + make_median(q2)] + if rugplot: + plot_data.append(make_violin_rugplot(vals, max_pdf, distance=distance, + color=fillcolor)) + return plot_data, plot_xrange + + +def violin_no_colorscale(data, data_header, group_header, colors, + use_colorscale, group_stats, + height, width, title): + """ + Refer to FigureFactory.create_violin() for docstring. + + Returns fig for violin plot without colorscale. + + """ + + # collect all group names + group_name = [] + for name in data[group_header]: + if name not in group_name: + group_name.append(name) + group_name.sort() + + gb = data.groupby([group_header]) + L = len(group_name) + + fig = make_subplots(rows=1, cols=L, + shared_yaxes=True, + horizontal_spacing=0.025, + print_grid=False) + color_index = 0 + for k, gr in enumerate(group_name): + vals = np.asarray(gb.get_group(gr)[data_header], np.float) + if color_index >= len(colors): + color_index = 0 + plot_data, plot_xrange = violinplot(vals, + fillcolor=colors[color_index]) + layout = graph_objs.Layout() + + for item in plot_data: + fig.append_trace(item, 1, k + 1) + color_index += 1 + + # add violin plot labels + fig['layout'].update( + {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)} + ) + + # set the sharey axis style + fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')}) + fig['layout'].update( + title=title, + showlegend=False, + hovermode='closest', + autosize=False, + height=height, + width=width + ) + + return fig + + +def violin_colorscale(data, data_header, group_header, colors, use_colorscale, + group_stats, height, width, title): + """ + Refer to FigureFactory.create_violin() for docstring. + + Returns fig for violin plot with colorscale. + + """ + + # collect all group names + group_name = [] + for name in data[group_header]: + if name not in group_name: + group_name.append(name) + group_name.sort() + + # make sure all group names are keys in group_stats + for group in group_name: + if group not in group_stats: + raise exceptions.PlotlyError("All values/groups in the index " + "column must be represented " + "as a key in group_stats.") + + gb = data.groupby([group_header]) + L = len(group_name) + + fig = make_subplots(rows=1, cols=L, + shared_yaxes=True, + horizontal_spacing=0.025, + print_grid=False) + + # prepare low and high color for colorscale + lowcolor = utils.color_parser(colors[0], utils.unlabel_rgb) + highcolor = utils.color_parser(colors[1], utils.unlabel_rgb) + + # find min and max values in group_stats + group_stats_values = [] + for key in group_stats: + group_stats_values.append(group_stats[key]) + + max_value = max(group_stats_values) + min_value = min(group_stats_values) + + for k, gr in enumerate(group_name): + vals = np.asarray(gb.get_group(gr)[data_header], np.float) + + # find intermediate color from colorscale + intermed = (group_stats[gr] - min_value) / (max_value - min_value) + intermed_color = utils.find_intermediate_color( + lowcolor, highcolor, intermed + ) + + plot_data, plot_xrange = violinplot( + vals, + fillcolor='rgb{}'.format(intermed_color) + ) + layout = graph_objs.Layout() + + for item in plot_data: + fig.append_trace(item, 1, k + 1) + fig['layout'].update( + {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)} + ) + # add colorbar to plot + trace_dummy = graph_objs.Scatter( + x=[0], + y=[0], + mode='markers', + marker=dict( + size=2, + cmin=min_value, + cmax=max_value, + colorscale=[[0, colors[0]], + [1, colors[1]]], + showscale=True), + showlegend=False, + ) + fig.append_trace(trace_dummy, 1, L) + + # set the sharey axis style + fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')}) + fig['layout'].update( + title=title, + showlegend=False, + hovermode='closest', + autosize=False, + height=height, + width=width + ) + + return fig + + +def violin_dict(data, data_header, group_header, colors, use_colorscale, + group_stats, height, width, title): + """ + Refer to FigureFactory.create_violin() for docstring. + + Returns fig for violin plot without colorscale. + + """ + + # collect all group names + group_name = [] + for name in data[group_header]: + if name not in group_name: + group_name.append(name) + group_name.sort() + + # check if all group names appear in colors dict + for group in group_name: + if group not in colors: + raise exceptions.PlotlyError("If colors is a dictionary, all " + "the group names must appear as " + "keys in colors.") + + gb = data.groupby([group_header]) + L = len(group_name) + + fig = make_subplots(rows=1, cols=L, + shared_yaxes=True, + horizontal_spacing=0.025, + print_grid=False) + + for k, gr in enumerate(group_name): + vals = np.asarray(gb.get_group(gr)[data_header], np.float) + plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr]) + layout = graph_objs.Layout() + + for item in plot_data: + fig.append_trace(item, 1, k + 1) + + # add violin plot labels + fig['layout'].update( + {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)} + ) + + # set the sharey axis style + fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')}) + fig['layout'].update( + title=title, + showlegend=False, + hovermode='closest', + autosize=False, + height=height, + width=width + ) + + return fig + + +def create_violin(data, data_header=None, group_header=None, colors=None, + use_colorscale=False, group_stats=None, height=450, + width=600, title='Violin and Rug Plot'): + """ + Returns figure for a violin plot + + :param (list|array) data: accepts either a list of numerical values, + a list of dictionaries all with identical keys and at least one + column of numeric values, or a pandas dataframe with at least one + column of numbers + :param (str) data_header: the header of the data column to be used + from an inputted pandas dataframe. Not applicable if 'data' is + a list of numeric values + :param (str) group_header: applicable if grouping data by a variable. + 'group_header' must be set to the name of the grouping variable. + :param (str|tuple|list|dict) colors: either a plotly scale name, + an rgb or hex color, a color tuple, a list of colors or a + dictionary. An rgb color is of the form 'rgb(x, y, z)' where + x, y and z belong to the interval [0, 255] and a color tuple is a + tuple of the form (a, b, c) where a, b and c belong to [0, 1]. + If colors is a list, it must contain valid color types as its + members. + :param (bool) use_colorscale: Only applicable if grouping by another + variable. Will implement a colorscale based on the first 2 colors + of param colors. This means colors must be a list with at least 2 + colors in it (Plotly colorscales are accepted since they map to a + list of two rgb colors) + :param (dict) group_stats: a dictioanry where each key is a unique + value from the group_header column in data. Each value must be a + number and will be used to color the violin plots if a colorscale + is being used + :param (float) height: the height of the violin plot + :param (float) width: the width of the violin plot + :param (str) title: the title of the violin plot + + Example 1: Single Violin Plot + ``` + import plotly.plotly as py + from plotly.figure_factory import create_violin + from plotly.graph_objs import graph_objs + + import numpy as np + from scipy import stats + + # create list of random values + data_list = np.random.randn(100) + data_list.tolist() + + # create violin fig + fig = create_violin(data_list, colors='#604d9e') + + # plot + py.iplot(fig, filename='Violin Plot') + ``` + + Example 2: Multiple Violin Plots with Qualitative Coloring + ``` + import plotly.plotly as py + from plotly.figure_factory import create_violin + from plotly.graph_objs import graph_objs + + import numpy as np + import pandas as pd + from scipy import stats + + # create dataframe + np.random.seed(619517) + Nr=250 + y = np.random.randn(Nr) + gr = np.random.choice(list("ABCDE"), Nr) + norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)] + + for i, letter in enumerate("ABCDE"): + y[gr == letter] *=norm_params[i][1]+ norm_params[i][0] + df = pd.DataFrame(dict(Score=y, Group=gr)) + + # create violin fig + fig = create_violin(df, data_header='Score', group_header='Group', + height=600, width=1000) + + # plot + py.iplot(fig, filename='Violin Plot with Coloring') + ``` + + Example 3: Violin Plots with Colorscale + ``` + import plotly.plotly as py + from plotly.figure_factory import create_violin + from plotly.graph_objs import graph_objs + + import numpy as np + import pandas as pd + from scipy import stats + + # create dataframe + np.random.seed(619517) + Nr=250 + y = np.random.randn(Nr) + gr = np.random.choice(list("ABCDE"), Nr) + norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)] + + for i, letter in enumerate("ABCDE"): + y[gr == letter] *=norm_params[i][1]+ norm_params[i][0] + df = pd.DataFrame(dict(Score=y, Group=gr)) + + # define header params + data_header = 'Score' + group_header = 'Group' + + # make groupby object with pandas + group_stats = {} + groupby_data = df.groupby([group_header]) + + for group in "ABCDE": + data_from_group = groupby_data.get_group(group)[data_header] + # take a stat of the grouped data + stat = np.median(data_from_group) + # add to dictionary + group_stats[group] = stat + + # create violin fig + fig = create_violin(df, data_header='Score', group_header='Group', + height=600, width=1000, use_colorscale=True, + group_stats=group_stats) + + # plot + py.iplot(fig, filename='Violin Plot with Colorscale') + ``` + """ + + # Validate colors + if isinstance(colors, dict): + valid_colors = utils.validate_colors_dict(colors, 'rgb') + else: + valid_colors = utils.validate_colors(colors, 'rgb') + + # validate data and choose plot type + if group_header is None: + if isinstance(data, list): + if len(data) <= 0: + raise exceptions.PlotlyError("If data is a list, it must be " + "nonempty and contain either " + "numbers or dictionaries.") + + if not all(isinstance(element, Number) for element in data): + raise exceptions.PlotlyError("If data is a list, it must " + "contain only numbers.") + + if pd and isinstance(data, pd.core.frame.DataFrame): + if data_header is None: + raise exceptions.PlotlyError("data_header must be the " + "column name with the " + "desired numeric data for " + "the violin plot.") + + data = data[data_header].values.tolist() + + # call the plotting functions + plot_data, plot_xrange = violinplot(data, fillcolor=valid_colors[0]) + + layout = graph_objs.Layout( + title=title, + autosize=False, + font=graph_objs.Font(size=11), + height=height, + showlegend=False, + width=width, + xaxis=make_XAxis('', plot_xrange), + yaxis=make_YAxis(''), + hovermode='closest' + ) + layout['yaxis'].update(dict(showline=False, + showticklabels=False, + ticks='')) + + fig = graph_objs.Figure(data=graph_objs.Data(plot_data), + layout=layout) + + return fig + + else: + if not isinstance(data, pd.core.frame.DataFrame): + raise exceptions.PlotlyError("Error. You must use a pandas " + "DataFrame if you are using a " + "group header.") + + if data_header is None: + raise exceptions.PlotlyError("data_header must be the column " + "name with the desired numeric " + "data for the violin plot.") + + if use_colorscale is False: + if isinstance(valid_colors, dict): + # validate colors dict choice below + fig = violin_dict( + data, data_header, group_header, valid_colors, + use_colorscale, group_stats, height, width, title + ) + return fig + else: + fig = violin_no_colorscale( + data, data_header, group_header, valid_colors, + use_colorscale, group_stats, height, width, title + ) + return fig + else: + if isinstance(valid_colors, dict): + raise exceptions.PlotlyError("The colors param cannot be " + "a dictionary if you are " + "using a colorscale.") + + if len(valid_colors) < 2: + raise exceptions.PlotlyError("colors must be a list with " + "at least 2 colors. A " + "Plotly scale is allowed.") + + if not isinstance(group_stats, dict): + raise exceptions.PlotlyError("Your group_stats param " + "must be a dictionary.") + + fig = violin_colorscale( + data, data_header, group_header, valid_colors, + use_colorscale, group_stats, height, width, title + ) + return fig diff --git a/plotly/figure_factory/utils.py b/plotly/figure_factory/utils.py new file mode 100644 index 00000000000..114cd2be5d5 --- /dev/null +++ b/plotly/figure_factory/utils.py @@ -0,0 +1,376 @@ +from __future__ import absolute_import + +import decimal + +from plotly import exceptions + +DEFAULT_PLOTLY_COLORS = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)', + 'rgb(44, 160, 44)', 'rgb(214, 39, 40)', + 'rgb(148, 103, 189)', 'rgb(140, 86, 75)', + 'rgb(227, 119, 194)', 'rgb(127, 127, 127)', + 'rgb(188, 189, 34)', 'rgb(23, 190, 207)'] + +PLOTLY_SCALES = {'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'], + 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'], + 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'], + 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'], + 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'], + 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'], + 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'], + 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'], + 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'], + 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'], + 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'], + 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'], + 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'], + 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'], + 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'], + 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'], + 'Viridis': ['rgb(68,1,84)', 'rgb(253,231,37)']} + + +def validate_index(index_vals): + """ + Validates if a list contains all numbers or all strings + + :raises: (PlotlyError) If there are any two items in the list whose + types differ + """ + from numbers import Number + if isinstance(index_vals[0], Number): + if not all(isinstance(item, Number) for item in index_vals): + raise exceptions.PlotlyError("Error in indexing column. " + "Make sure all entries of each " + "column are all numbers or " + "all strings.") + + elif isinstance(index_vals[0], str): + if not all(isinstance(item, str) for item in index_vals): + raise exceptions.PlotlyError("Error in indexing column. " + "Make sure all entries of each " + "column are all numbers or " + "all strings.") + + +def validate_dataframe(array): + """ + Validates all strings or numbers in each dataframe column + + :raises: (PlotlyError) If there are any two items in any list whose + types differ + """ + from numbers import Number + for vector in array: + if isinstance(vector[0], Number): + if not all(isinstance(item, Number) for item in vector): + raise exceptions.PlotlyError("Error in dataframe. " + "Make sure all entries of " + "each column are either " + "numbers or strings.") + elif isinstance(vector[0], str): + if not all(isinstance(item, str) for item in vector): + raise exceptions.PlotlyError("Error in dataframe. " + "Make sure all entries of " + "each column are either " + "numbers or strings.") + + +def validate_equal_length(*args): + """ + Validates that data lists or ndarrays are the same length. + + :raises: (PlotlyError) If any data lists are not the same length. + """ + length = len(args[0]) + if any(len(lst) != length for lst in args): + raise exceptions.PlotlyError("Oops! Your data lists or ndarrays " + "should be the same length.") + + +def validate_positive_scalars(**kwargs): + """ + Validates that all values given in key/val pairs are positive. + + Accepts kwargs to improve Exception messages. + + :raises: (PlotlyError) If any value is < 0 or raises. + """ + for key, val in kwargs.items(): + try: + if val <= 0: + raise ValueError('{} must be > 0, got {}'.format(key, val)) + except TypeError: + raise exceptions.PlotlyError('{} must be a number, got {}' + .format(key, val)) + + +def flatten(array): + """ + Uses list comprehension to flatten array + + :param (array): An iterable to flatten + :raises (PlotlyError): If iterable is not nested. + :rtype (list): The flattened list. + """ + try: + return [item for sublist in array for item in sublist] + except TypeError: + raise exceptions.PlotlyError("Your data array could not be " + "flattened! Make sure your data is " + "entered as lists or ndarrays!") + + +def find_intermediate_color(lowcolor, highcolor, intermed): + """ + Returns the color at a given distance between two colors + + This function takes two color tuples, where each element is between 0 + and 1, along with a value 0 < intermed < 1 and returns a color that is + intermed-percent from lowcolor to highcolor + + """ + diff_0 = float(highcolor[0] - lowcolor[0]) + diff_1 = float(highcolor[1] - lowcolor[1]) + diff_2 = float(highcolor[2] - lowcolor[2]) + + return (lowcolor[0] + intermed * diff_0, + lowcolor[1] + intermed * diff_1, + lowcolor[2] + intermed * diff_2) + + +def n_colors(lowcolor, highcolor, n_colors): + """ + Splits a low and high color into a list of n_colors colors in it + + Accepts two color tuples and returns a list of n_colors colors + which form the intermediate colors between lowcolor and highcolor + from linearly interpolating through RGB space + + """ + diff_0 = float(highcolor[0] - lowcolor[0]) + incr_0 = diff_0/(n_colors - 1) + diff_1 = float(highcolor[1] - lowcolor[1]) + incr_1 = diff_1/(n_colors - 1) + diff_2 = float(highcolor[2] - lowcolor[2]) + incr_2 = diff_2/(n_colors - 1) + color_tuples = [] + + for index in range(n_colors): + new_tuple = (lowcolor[0] + (index * incr_0), + lowcolor[1] + (index * incr_1), + lowcolor[2] + (index * incr_2)) + color_tuples.append(new_tuple) + + return color_tuples + + +def label_rgb(colors): + """ + Takes tuple (a, b, c) and returns an rgb color 'rgb(a, b, c)' + """ + return ('rgb(%s, %s, %s)' % (colors[0], colors[1], colors[2])) + + +def unlabel_rgb(colors): + """ + Takes rgb color(s) 'rgb(a, b, c)' and returns tuple(s) (a, b, c) + + This function takes either an 'rgb(a, b, c)' color or a list of + such colors and returns the color tuples in tuple(s) (a, b, c) + + """ + str_vals = '' + for index in range(len(colors)): + try: + float(colors[index]) + str_vals = str_vals + colors[index] + except ValueError: + if colors[index] == ',' or colors[index] == '.': + str_vals = str_vals + colors[index] + + str_vals = str_vals + ',' + numbers = [] + str_num = '' + for char in str_vals: + if char != ',': + str_num = str_num + char + else: + numbers.append(float(str_num)) + str_num = '' + return (numbers[0], numbers[1], numbers[2]) + + +def unconvert_from_RGB_255(colors): + """ + Return a tuple where each element gets divided by 255 + + Takes a (list of) color tuple(s) where each element is between 0 and + 255. Returns the same tuples where each tuple element is normalized to + a value between 0 and 1 + + """ + return (colors[0]/(255.0), + colors[1]/(255.0), + colors[2]/(255.0)) + + +def convert_to_RGB_255(colors): + """ + Multiplies each element of a triplet by 255 + + Each coordinate of the color tuple is rounded to the nearest float and + then is turned into an integer. If a number is of the form x.5, then + if x is odd, the number rounds up to (x+1). Otherwise, it rounds down + to just x. This is the way rounding works in Python 3 and in current + statistical analysis to avoid rounding bias + """ + rgb_components = [] + + for component in colors: + rounded_num = decimal.Decimal(str(component*255.0)).quantize( + decimal.Decimal('1'), rounding=decimal.ROUND_HALF_EVEN + ) + # convert rounded number to an integer from 'Decimal' form + rounded_num = int(rounded_num) + rgb_components.append(rounded_num) + + return (rgb_components[0], rgb_components[1], rgb_components[2]) + + +def hex_to_rgb(value): + """ + Calculates rgb values from a hex color code. + + :param (string) value: Hex color string + + :rtype (tuple) (r_value, g_value, b_value): tuple of rgb values + """ + value = value.lstrip('#') + hex_total_length = len(value) + rgb_section_length = hex_total_length // 3 + return tuple(int(value[i:i + rgb_section_length], 16) + for i in range(0, hex_total_length, rgb_section_length)) + + +def color_parser(colors, function): + """ + Takes color(s) and a function and applies the function on the color(s) + + In particular, this function identifies whether the given color object + is an iterable or not and applies the given color-parsing function to + the color or iterable of colors. If given an iterable, it will only be + able to work with it if all items in the iterable are of the same type + - rgb string, hex string or tuple + + """ + from numbers import Number + if isinstance(colors, str): + return function(colors) + + if isinstance(colors, tuple) and isinstance(colors[0], Number): + return function(colors) + + if hasattr(colors, '__iter__'): + if isinstance(colors, tuple): + new_color_tuple = tuple(function(item) for item in colors) + return new_color_tuple + + else: + new_color_list = [function(item) for item in colors] + return new_color_list + + +def validate_colors(colors, colortype='tuple'): + """ + Validates color(s) and returns a list of color(s) of a specified type + """ + from numbers import Number + if colors is None: + colors = DEFAULT_PLOTLY_COLORS + + if isinstance(colors, str): + if colors in PLOTLY_SCALES: + colors = PLOTLY_SCALES[colors] + elif 'rgb' in colors or '#' in colors: + colors = [colors] + else: + raise exceptions.PlotlyError( + "If your colors variable is a string, it must be a " + "Plotly scale, an rgb color or a hex color.") + + elif isinstance(colors, tuple): + if isinstance(colors[0], Number): + colors = [colors] + else: + colors = list(colors) + + # convert color elements in list to tuple color + for j, each_color in enumerate(colors): + if 'rgb' in each_color: + each_color = color_parser(each_color, unlabel_rgb) + for value in each_color: + if value > 255.0: + raise exceptions.PlotlyError( + "Whoops! The elements in your rgb colors " + "tuples cannot exceed 255.0." + ) + each_color = color_parser(each_color, unconvert_from_RGB_255) + colors[j] = each_color + + if '#' in each_color: + each_color = color_parser(each_color, hex_to_rgb) + each_color = color_parser(each_color, unconvert_from_RGB_255) + + colors[j] = each_color + + if isinstance(each_color, tuple): + for value in each_color: + if value > 1.0: + raise exceptions.PlotlyError( + "Whoops! The elements in your colors tuples " + "cannot exceed 1.0." + ) + colors[j] = each_color + + if colortype == 'rgb': + for j, each_color in enumerate(colors): + rgb_color = color_parser(each_color, convert_to_RGB_255) + colors[j] = color_parser(rgb_color, label_rgb) + + return colors + + +def validate_colors_dict(colors, colortype='tuple'): + """ + Validates dictioanry of color(s) + """ + # validate each color element in the dictionary + for key in colors: + if 'rgb' in colors[key]: + colors[key] = color_parser(colors[key], unlabel_rgb) + for value in colors[key]: + if value > 255.0: + raise exceptions.PlotlyError( + "Whoops! The elements in your rgb colors " + "tuples cannot exceed 255.0." + ) + colors[key] = color_parser(colors[key], unconvert_from_RGB_255) + + if '#' in colors[key]: + colors[key] = color_parser(colors[key], hex_to_rgb) + colors[key] = color_parser(colors[key], unconvert_from_RGB_255) + + if isinstance(colors[key], tuple): + for value in colors[key]: + if value > 1.0: + raise exceptions.PlotlyError( + "Whoops! The elements in your colors tuples " + "cannot exceed 1.0." + ) + + if colortype == 'rgb': + for key in colors: + colors[key] = color_parser(colors[key], convert_to_RGB_255) + colors[key] = color_parser(colors[key], label_rgb) + + return colors diff --git a/plotly/grid_objs/grid_objs.py b/plotly/grid_objs/grid_objs.py index 782e3dc2fe0..128d3bf90a1 100644 --- a/plotly/grid_objs/grid_objs.py +++ b/plotly/grid_objs/grid_objs.py @@ -5,9 +5,10 @@ """ from __future__ import absolute_import -import json from collections import MutableSequence +from requests.compat import json as _json + from plotly import exceptions, utils __all__ = None @@ -66,7 +67,7 @@ def __init__(self, data, name): def __str__(self): max_chars = 10 - jdata = json.dumps(self.data, cls=utils.PlotlyJSONEncoder) + jdata = _json.dumps(self.data, cls=utils.PlotlyJSONEncoder) if len(jdata) > max_chars: data_string = jdata[:max_chars] + "...]" else: diff --git a/plotly/offline/offline.py b/plotly/offline/offline.py index 965b5af09e0..2228d52b00d 100644 --- a/plotly/offline/offline.py +++ b/plotly/offline/offline.py @@ -5,7 +5,6 @@ """ from __future__ import absolute_import -import json import os import uuid import warnings @@ -13,22 +12,15 @@ import time import webbrowser +from requests.compat import json as _json + import plotly -from plotly import tools, utils +from plotly import optional_imports, tools, utils from plotly.exceptions import PlotlyError -try: - import IPython - from IPython.display import HTML, display - _ipython_imported = True -except ImportError: - _ipython_imported = False - -try: - import matplotlib - _matplotlib_imported = True -except ImportError: - _matplotlib_imported = False +ipython = optional_imports.get_module('IPython') +ipython_display = optional_imports.get_module('IPython.display') +matplotlib = optional_imports.get_module('matplotlib') __PLOTLY_OFFLINE_INITIALIZED = False @@ -111,7 +103,7 @@ def init_notebook_mode(connected=False): your notebook, resulting in much larger notebook sizes compared to the case where `connected=True`. """ - if not _ipython_imported: + if not ipython: raise ImportError('`iplot` can only run inside an IPython Notebook.') global __PLOTLY_OFFLINE_INITIALIZED @@ -148,7 +140,7 @@ def init_notebook_mode(connected=False): '' '').format(script=get_plotlyjs()) - display(HTML(script_inject)) + ipython_display.display(ipython_display.HTML(script_inject)) __PLOTLY_OFFLINE_INITIALIZED = True @@ -183,10 +175,12 @@ def _plot_html(figure_or_data, config, validate, default_width, height = str(height) + 'px' plotdivid = uuid.uuid4() - jdata = json.dumps(figure.get('data', []), cls=utils.PlotlyJSONEncoder) - jlayout = json.dumps(figure.get('layout', {}), cls=utils.PlotlyJSONEncoder) + jdata = _json.dumps(figure.get('data', []), cls=utils.PlotlyJSONEncoder) + jlayout = _json.dumps(figure.get('layout', {}), + cls=utils.PlotlyJSONEncoder) if 'frames' in figure_or_data: - jframes = json.dumps(figure.get('frames', {}), cls=utils.PlotlyJSONEncoder) + jframes = _json.dumps(figure.get('frames', {}), + cls=utils.PlotlyJSONEncoder) configkeys = ( 'editable', @@ -211,7 +205,7 @@ def _plot_html(figure_or_data, config, validate, default_width, ) config_clean = dict((k, config[k]) for k in configkeys if k in config) - jconfig = json.dumps(config_clean) + jconfig = _json.dumps(config_clean) # TODO: The get_config 'source of truth' should # really be somewhere other than plotly.plotly @@ -330,7 +324,7 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly', 'plotly.offline.init_notebook_mode() ' '# run at the start of every ipython notebook', ])) - if not tools._ipython_imported: + if not ipython: raise ImportError('`iplot` can only run inside an IPython Notebook.') config = {} @@ -341,7 +335,7 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly', figure_or_data, config, validate, '100%', 525, True ) - display(HTML(plot_html)) + ipython_display.display(ipython_display.HTML(plot_html)) if image: if image not in __IMAGE_FORMATS: @@ -357,7 +351,7 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly', # allow time for the plot to draw time.sleep(1) # inject code to download an image of the plot - display(HTML(script)) + ipython_display.display(ipython_display.HTML(script)) def plot(figure_or_data, show_link=True, link_text='Export to plot.ly', @@ -687,7 +681,7 @@ def enable_mpl_offline(resize=False, strip_style=False, """ init_notebook_mode() - ip = IPython.core.getipython.get_ipython() + ip = ipython.core.getipython.get_ipython() formatter = ip.display_formatter.formatters['text/html'] formatter.for_type(matplotlib.figure.Figure, lambda fig: iplot_mpl(fig, resize, strip_style, verbose, diff --git a/plotly/optional_imports.py b/plotly/optional_imports.py new file mode 100644 index 00000000000..7e9ba805b42 --- /dev/null +++ b/plotly/optional_imports.py @@ -0,0 +1,25 @@ +""" +Stand-alone module to provide information about whether optional deps exist. + +""" +from __future__ import absolute_import + +from importlib import import_module + +_not_importable = set() + + +def get_module(name): + """ + Return module or None. Absolute import is required. + + :param (str) name: Dot-separated module path. E.g., 'scipy.stats'. + :raise: (ImportError) Only when exc_msg is defined. + :return: (module|None) If import succeeds, the module will be returned. + + """ + if name not in _not_importable: + try: + return import_module(name) + except ImportError: + _not_importable.add(name) diff --git a/plotly/plotly/plotly.py b/plotly/plotly/plotly.py index 756fc190000..a5bf559df71 100644 --- a/plotly/plotly/plotly.py +++ b/plotly/plotly/plotly.py @@ -16,25 +16,22 @@ """ from __future__ import absolute_import -import base64 import copy -import json import os import warnings -import requests import six import six.moves +from requests.compat import json as _json -from requests.auth import HTTPBasicAuth - -from plotly import exceptions, tools, utils, version, files +from plotly import exceptions, files, session, tools, utils +from plotly.api import v1, v2 from plotly.plotly import chunked_requests -from plotly.session import (sign_in, update_session_plot_options, - get_session_plot_options, get_session_credentials, - get_session_config) from plotly.grid_objs import Grid, Column +# This is imported like this for backwards compat. Careful if changing. +from plotly.config import get_config, get_credentials + __all__ = None DEFAULT_PLOT_OPTIONS = { @@ -51,34 +48,16 @@ # don't break backwards compatibility -sign_in = sign_in -update_plot_options = update_session_plot_options - - -def get_credentials(): - """Returns the credentials that will be sent to plotly.""" - credentials = tools.get_credentials_file() - session_credentials = get_session_credentials() - for credentials_key in credentials: - - # checking for not false, but truthy value here is the desired behavior - session_value = session_credentials.get(credentials_key) - if session_value is False or session_value: - credentials[credentials_key] = session_value - return credentials - - -def get_config(): - """Returns either module config or file config.""" - config = tools.get_config_file() - session_config = get_session_config() - for config_key in config: +def sign_in(username, api_key, **kwargs): + session.sign_in(username, api_key, **kwargs) + try: + # The only way this can succeed is if the user can be authenticated + # with the given, username, api_key, and plotly_api_domain. + v2.users.current() + except exceptions.PlotlyRequestError: + raise exceptions.PlotlyError('Sign in failed.') - # checking for not false, but truthy value here is the desired behavior - session_value = session_config.get(config_key) - if session_value is False or session_value: - config[config_key] = session_value - return config +update_plot_options = session.update_session_plot_options def _plot_option_logic(plot_options_from_call_signature): @@ -93,7 +72,7 @@ def _plot_option_logic(plot_options_from_call_signature): """ default_plot_options = copy.deepcopy(DEFAULT_PLOT_OPTIONS) file_options = tools.get_config_file() - session_options = get_session_plot_options() + session_options = session.get_session_plot_options() plot_options_from_call_signature = copy.deepcopy(plot_options_from_call_signature) # Validate options and fill in defaults w world_readable and sharing @@ -238,15 +217,23 @@ def plot(figure_or_data, validate=True, **plot_options): pass plot_options = _plot_option_logic(plot_options) - res = _send_to_plotly(figure, **plot_options) - if res['error'] == '': - if plot_options['auto_open']: - _open_url(res['url']) + fig = tools._replace_newline(figure) # does not mutate figure + data = fig.get('data', []) + plot_options['layout'] = fig.get('layout', {}) + response = v1.clientresp(data, **plot_options) - return res['url'] - else: - raise exceptions.PlotlyAccountError(res['error']) + # Check if the url needs a secret key + url = response.json()['url'] + if plot_options['sharing'] == 'secret': + if 'share_key=' not in url: + # add_share_key_to_url updates the url to include the share_key + url = add_share_key_to_url(url) + + if plot_options['auto_open']: + _open_url(url) + + return url def iplot_mpl(fig, resize=True, strip_style=False, update=None, @@ -316,6 +303,64 @@ def plot_mpl(fig, resize=True, strip_style=False, update=None, **plot_options): return plot(fig, **plot_options) +def _swap_keys(obj, key1, key2): + """Swap obj[key1] with obj[key2]""" + val1, val2 = None, None + try: + val2 = obj.pop(key1) + except KeyError: + pass + try: + val1 = obj.pop(key2) + except KeyError: + pass + if val2 is not None: + obj[key2] = val2 + if val1 is not None: + obj[key1] = val1 + + +def _swap_xy_data(data_obj): + """Swap x and y data and references""" + swaps = [('x', 'y'), + ('x0', 'y0'), + ('dx', 'dy'), + ('xbins', 'ybins'), + ('nbinsx', 'nbinsy'), + ('autobinx', 'autobiny'), + ('error_x', 'error_y')] + for swap in swaps: + _swap_keys(data_obj, swap[0], swap[1]) + try: + rows = len(data_obj['z']) + cols = len(data_obj['z'][0]) + for row in data_obj['z']: + if len(row) != cols: + raise TypeError + + # if we can't do transpose, we hit an exception before here + z = data_obj.pop('z') + data_obj['z'] = [[0 for rrr in range(rows)] for ccc in range(cols)] + for iii in range(rows): + for jjj in range(cols): + data_obj['z'][jjj][iii] = z[iii][jjj] + except (KeyError, TypeError, IndexError) as err: + warn = False + try: + if data_obj['z'] is not None: + warn = True + if len(data_obj['z']) == 0: + warn = False + except (KeyError, TypeError): + pass + if warn: + warnings.warn( + "Data in this file required an 'xy' swap but the 'z' matrix " + "in one of the data objects could not be transposed. Here's " + "why:\n\n{}".format(repr(err)) + ) + + def get_figure(file_owner_or_url, file_id=None, raw=False): """Returns a JSON figure representation for the specified file @@ -363,15 +408,6 @@ def get_figure(file_owner_or_url, file_id=None, raw=False): file_id = url.replace(head, "").split('/')[1] else: file_owner = file_owner_or_url - resource = "/apigetfile/{username}/{file_id}".format(username=file_owner, - file_id=file_id) - credentials = get_credentials() - validate_credentials(credentials) - username, api_key = credentials['username'], credentials['api_key'] - headers = {'plotly-username': username, - 'plotly-apikey': api_key, - 'plotly-version': version.__version__, - 'plotly-platform': 'python'} try: int(file_id) except ValueError: @@ -386,28 +422,49 @@ def get_figure(file_owner_or_url, file_id=None, raw=False): "The 'file_id' argument must be a non-negative number." ) - response = requests.get(plotly_rest_url + resource, - headers=headers, - verify=get_config()['plotly_ssl_verification']) - if response.status_code == 200: - if six.PY3: - content = json.loads(response.content.decode('utf-8')) - else: - content = json.loads(response.content) - response_payload = content['payload'] - figure = response_payload['figure'] - utils.decode_unicode(figure) - if raw: - return figure - else: - return tools.get_valid_graph_obj(figure, obj_type='Figure') - else: + fid = '{}:{}'.format(file_owner, file_id) + response = v2.plots.content(fid, inline_data=True) + figure = response.json() + + # Fix 'histogramx', 'histogramy', and 'bardir' stuff + for index, entry in enumerate(figure['data']): try: - content = json.loads(response.content) - raise exceptions.PlotlyError(content) - except: - raise exceptions.PlotlyError( - "There was an error retrieving this file") + # Use xbins to bin data in x, and ybins to bin data in y + if all((entry['type'] == 'histogramy', 'xbins' in entry, + 'ybins' not in entry)): + entry['ybins'] = entry.pop('xbins') + + # Convert bardir to orientation, and put the data into the axes + # it's eventually going to be used with + if entry['type'] in ['histogramx', 'histogramy']: + entry['type'] = 'histogram' + if 'bardir' in entry: + entry['orientation'] = entry.pop('bardir') + if entry['type'] == 'bar': + if entry['orientation'] == 'h': + _swap_xy_data(entry) + if entry['type'] == 'histogram': + if ('x' in entry) and ('y' not in entry): + if entry['orientation'] == 'h': + _swap_xy_data(entry) + del entry['orientation'] + if ('y' in entry) and ('x' not in entry): + if entry['orientation'] == 'v': + _swap_xy_data(entry) + del entry['orientation'] + figure['data'][index] = entry + except KeyError: + pass + + # Remove stream dictionary if found in a data trace + # (it has private tokens in there we need to hide!) + for index, entry in enumerate(figure['data']): + if 'stream' in entry: + del figure['data'][index]['stream'] + + if raw: + return figure + return tools.get_valid_graph_obj(figure, obj_type='Figure') @utils.template_doc(**tools.get_config_file()) @@ -592,7 +649,7 @@ def write(self, trace, layout=None, validate=True, stream_object.update(dict(layout=layout)) # TODO: allow string version of this? - jdata = json.dumps(stream_object, cls=utils.PlotlyJSONEncoder) + jdata = _json.dumps(stream_object, cls=utils.PlotlyJSONEncoder) jdata += "\n" try: @@ -673,10 +730,6 @@ def get(figure_or_data, format='png', width=None, height=None, scale=None): "Invalid scale parameter. Scale must be a number." ) - headers = _api_v2.headers() - headers['plotly_version'] = version.__version__ - headers['content-type'] = 'application/json' - payload = {'figure': figure, 'format': format} if width is not None: payload['width'] = width @@ -684,38 +737,18 @@ def get(figure_or_data, format='png', width=None, height=None, scale=None): payload['height'] = height if scale is not None: payload['scale'] = scale - url = _api_v2.api_url('images/') - - res = requests.post( - url, data=json.dumps(payload, cls=utils.PlotlyJSONEncoder), - headers=headers, verify=get_config()['plotly_ssl_verification'], - ) - - headers = res.headers - if res.status_code == 200: - if ('content-type' in headers and - headers['content-type'] in ['image/png', 'image/jpeg', - 'application/pdf', - 'image/svg+xml']): - return res.content + response = v2.images.create(payload) - elif ('content-type' in headers and - 'json' in headers['content-type']): - return_data = json.loads(res.content) - return return_data['image'] - else: - try: - if ('content-type' in headers and - 'json' in headers['content-type']): - return_data = json.loads(res.content) - else: - return_data = {'error': res.content} - except: - raise exceptions.PlotlyError("The response " - "from plotly could " - "not be translated.") - raise exceptions.PlotlyError(return_data['error']) + headers = response.headers + if ('content-type' in headers and + headers['content-type'] in ['image/png', 'image/jpeg', + 'application/pdf', + 'image/svg+xml']): + return response.content + elif ('content-type' in headers and + 'json' in headers['content-type']): + return response.json()['image'] @classmethod def ishow(cls, figure_or_data, format='png', width=None, height=None, @@ -829,22 +862,8 @@ def mkdirs(cls, folder_path): >> mkdirs('new/folder/path') """ - # trim trailing slash TODO: necessesary? - if folder_path[-1] == '/': - folder_path = folder_path[0:-1] - - payload = { - 'path': folder_path - } - - url = _api_v2.api_url('folders') - - res = requests.post(url, data=payload, headers=_api_v2.headers(), - verify=get_config()['plotly_ssl_verification']) - - _api_v2.response_handler(res) - - return res.status_code + response = v2.folders.create({'path': folder_path}) + return response.status_code class grid_ops: @@ -874,6 +893,15 @@ def _fill_in_response_column_ids(cls, request_columns, req_col.id = '{0}:{1}'.format(grid_id, resp_col['uid']) response_columns.remove(resp_col) + @staticmethod + def ensure_uploaded(fid): + if fid: + return + raise exceptions.PlotlyError( + 'This operation requires that the grid has already been uploaded ' + 'to Plotly. Try `uploading` first.' + ) + @classmethod def upload(cls, grid, filename, world_readable=True, auto_open=True, meta=None): @@ -954,37 +982,32 @@ def upload(cls, grid, filename, payload = { 'filename': filename, - 'data': json.dumps(grid_json, cls=utils.PlotlyJSONEncoder), + 'data': grid_json, 'world_readable': world_readable } if parent_path != '': payload['parent_path'] = parent_path - upload_url = _api_v2.api_url('grids') + response = v2.grids.create(payload) - req = requests.post(upload_url, data=payload, - headers=_api_v2.headers(), - verify=get_config()['plotly_ssl_verification']) - - res = _api_v2.response_handler(req) - - response_columns = res['file']['cols'] - grid_id = res['file']['fid'] - grid_url = res['file']['web_url'] + parsed_content = response.json() + cols = parsed_content['file']['cols'] + fid = parsed_content['file']['fid'] + web_url = parsed_content['file']['web_url'] # mutate the grid columns with the id's returned from the server - cls._fill_in_response_column_ids(grid, response_columns, grid_id) + cls._fill_in_response_column_ids(grid, cols, fid) - grid.id = grid_id + grid.id = fid if meta is not None: meta_ops.upload(meta, grid=grid) if auto_open: - _open_url(grid_url) + _open_url(web_url) - return grid_url + return web_url @classmethod def append_columns(cls, columns, grid=None, grid_url=None): @@ -1024,7 +1047,9 @@ def append_columns(cls, columns, grid=None, grid_url=None): ``` """ - grid_id = _api_v2.parse_grid_id_args(grid, grid_url) + grid_id = parse_grid_id_args(grid, grid_url) + + grid_ops.ensure_uploaded(grid_id) # Verify unique column names column_names = [c.name for c in columns] @@ -1036,17 +1061,15 @@ def append_columns(cls, columns, grid=None, grid_url=None): err = exceptions.NON_UNIQUE_COLUMN_MESSAGE.format(duplicate_name) raise exceptions.InputError(err) - payload = { - 'cols': json.dumps(columns, cls=utils.PlotlyJSONEncoder) + # This is sorta gross, we need to double-encode this. + body = { + 'cols': _json.dumps(columns, cls=utils.PlotlyJSONEncoder) } + fid = grid_id + response = v2.grids.col_create(fid, body) + parsed_content = response.json() - api_url = (_api_v2.api_url('grids') + - '/{grid_id}/col'.format(grid_id=grid_id)) - res = requests.post(api_url, data=payload, headers=_api_v2.headers(), - verify=get_config()['plotly_ssl_verification']) - res = _api_v2.response_handler(res) - - cls._fill_in_response_column_ids(columns, res['cols'], grid_id) + cls._fill_in_response_column_ids(columns, parsed_content['cols'], fid) if grid: grid.extend(columns) @@ -1096,7 +1119,9 @@ def append_rows(cls, rows, grid=None, grid_url=None): ``` """ - grid_id = _api_v2.parse_grid_id_args(grid, grid_url) + grid_id = parse_grid_id_args(grid, grid_url) + + grid_ops.ensure_uploaded(grid_id) if grid: n_columns = len([column for column in grid]) @@ -1112,15 +1137,8 @@ def append_rows(cls, rows, grid=None, grid_url=None): n_columns, 'column' if n_columns == 1 else 'columns')) - payload = { - 'rows': json.dumps(rows, cls=utils.PlotlyJSONEncoder) - } - - api_url = (_api_v2.api_url('grids') + - '/{grid_id}/row'.format(grid_id=grid_id)) - res = requests.post(api_url, data=payload, headers=_api_v2.headers(), - verify=get_config()['plotly_ssl_verification']) - _api_v2.response_handler(res) + fid = grid_id + v2.grids.row(fid, {'rows': rows}) if grid: longest_column_length = max([len(col.data) for col in grid]) @@ -1168,11 +1186,10 @@ def delete(cls, grid=None, grid_url=None): ``` """ - grid_id = _api_v2.parse_grid_id_args(grid, grid_url) - api_url = _api_v2.api_url('grids') + '/' + grid_id - res = requests.delete(api_url, headers=_api_v2.headers(), - verify=get_config()['plotly_ssl_verification']) - _api_v2.response_handler(res) + fid = parse_grid_id_args(grid, grid_url) + grid_ops.ensure_uploaded(fid) + v2.grids.trash(fid) + v2.grids.permanent_delete(fid) class meta_ops: @@ -1230,269 +1247,99 @@ def upload(cls, meta, grid=None, grid_url=None): ``` """ - grid_id = _api_v2.parse_grid_id_args(grid, grid_url) + fid = parse_grid_id_args(grid, grid_url) + return v2.grids.update(fid, {'metadata': meta}).json() - payload = { - 'metadata': json.dumps(meta, cls=utils.PlotlyJSONEncoder) - } - - api_url = _api_v2.api_url('grids') + '/{grid_id}'.format(grid_id=grid_id) - - res = requests.patch(api_url, data=payload, headers=_api_v2.headers(), - verify=get_config()['plotly_ssl_verification']) - - return _api_v2.response_handler(res) - - -class _api_v2: - """ - Request and response helper class for communicating with Plotly's v2 API +def parse_grid_id_args(grid, grid_url): """ - @classmethod - def parse_grid_id_args(cls, grid, grid_url): - """ - Return the grid_id from the non-None input argument. - - Raise an error if more than one argument was supplied. - - """ - if grid is not None: - id_from_grid = grid.id - else: - id_from_grid = None - args = [id_from_grid, grid_url] - arg_names = ('grid', 'grid_url') - - supplied_arg_names = [arg_name for arg_name, arg - in zip(arg_names, args) if arg is not None] - - if not supplied_arg_names: - raise exceptions.InputError( - "One of the two keyword arguments is required:\n" - " `grid` or `grid_url`\n\n" - "grid: a plotly.graph_objs.Grid object that has already\n" - " been uploaded to Plotly.\n\n" - "grid_url: the url where the grid can be accessed on\n" - " Plotly, e.g. 'https://plot.ly/~chris/3043'\n\n" - ) - elif len(supplied_arg_names) > 1: - raise exceptions.InputError( - "Only one of `grid` or `grid_url` is required. \n" - "You supplied both. \n" - ) - else: - supplied_arg_name = supplied_arg_names.pop() - if supplied_arg_name == 'grid_url': - path = six.moves.urllib.parse.urlparse(grid_url).path - file_owner, file_id = path.replace("/~", "").split('/')[0:2] - return '{0}:{1}'.format(file_owner, file_id) - else: - return grid.id - - @classmethod - def response_handler(cls, response): - try: - response.raise_for_status() - except requests.exceptions.HTTPError as requests_exception: - if (response.status_code == 404 and - get_config()['plotly_api_domain'] - != tools.get_config_defaults()['plotly_api_domain']): - raise exceptions.PlotlyError( - "This endpoint is unavailable at {url}. If you are using " - "Plotly On-Premise, you may need to upgrade your Plotly " - "Plotly On-Premise server to request against this endpoint or " - "this endpoint may not be available yet.\nQuestions? " - "Visit community.plot.ly, contact your plotly administrator " - "or upgrade to a Pro account for 1-1 help: https://goo.gl/1YUVu9 " - .format(url=get_config()['plotly_api_domain']) - ) - else: - raise requests_exception + Return the grid_id from the non-None input argument. - if ('content-type' in response.headers and - 'json' in response.headers['content-type'] and - len(response.content) > 0): - - response_dict = json.loads(response.content.decode('utf8')) - - if 'warnings' in response_dict and len(response_dict['warnings']): - warnings.warn('\n'.join(response_dict['warnings'])) - - return response_dict - - @classmethod - def api_url(cls, resource): - return ('{0}/v2/{1}'.format(get_config()['plotly_api_domain'], - resource)) - - @classmethod - def headers(cls): - credentials = get_credentials() + Raise an error if more than one argument was supplied. - # todo, validate here? - username, api_key = credentials['username'], credentials['api_key'] - encoded_api_auth = base64.b64encode(six.b('{0}:{1}'.format( - username, api_key))).decode('utf8') - - headers = { - 'plotly-client-platform': 'python {0}'.format(version.__version__) - } - - if get_config()['plotly_proxy_authorization']: - proxy_username = credentials['proxy_username'] - proxy_password = credentials['proxy_password'] - encoded_proxy_auth = base64.b64encode(six.b('{0}:{1}'.format( - proxy_username, proxy_password))).decode('utf8') - headers['authorization'] = 'Basic ' + encoded_proxy_auth - headers['plotly-authorization'] = 'Basic ' + encoded_api_auth - else: - headers['authorization'] = 'Basic ' + encoded_api_auth - - return headers - - -def validate_credentials(credentials): """ - Currently only checks for truthy username and api_key - - """ - username = credentials.get('username') - api_key = credentials.get('api_key') - if not username or not api_key: - raise exceptions.PlotlyLocalCredentialsError() + if grid is not None: + id_from_grid = grid.id + else: + id_from_grid = None + args = [id_from_grid, grid_url] + arg_names = ('grid', 'grid_url') + + supplied_arg_names = [arg_name for arg_name, arg + in zip(arg_names, args) if arg is not None] + + if not supplied_arg_names: + raise exceptions.InputError( + "One of the two keyword arguments is required:\n" + " `grid` or `grid_url`\n\n" + "grid: a plotly.graph_objs.Grid object that has already\n" + " been uploaded to Plotly.\n\n" + "grid_url: the url where the grid can be accessed on\n" + " Plotly, e.g. 'https://plot.ly/~chris/3043'\n\n" + ) + elif len(supplied_arg_names) > 1: + raise exceptions.InputError( + "Only one of `grid` or `grid_url` is required. \n" + "You supplied both. \n" + ) + else: + supplied_arg_name = supplied_arg_names.pop() + if supplied_arg_name == 'grid_url': + path = six.moves.urllib.parse.urlparse(grid_url).path + file_owner, file_id = path.replace("/~", "").split('/')[0:2] + return '{0}:{1}'.format(file_owner, file_id) + else: + return grid.id -def add_share_key_to_url(plot_url, attempt=0): +def add_share_key_to_url(plot_url): """ Update plot's url to include the secret key """ urlsplit = six.moves.urllib.parse.urlparse(plot_url) - file_owner = urlsplit.path.split('/')[1].split('~')[1] - file_id = urlsplit.path.split('/')[2] + username = urlsplit.path.split('/')[1].split('~')[1] + idlocal = urlsplit.path.split('/')[2] + fid = '{}:{}'.format(username, idlocal) - url = _api_v2.api_url("files/") + file_owner + ":" + file_id - new_response = requests.patch(url, - headers=_api_v2.headers(), - data={"share_key_enabled": - "True", - "world_readable": - "False"}) + body = {'share_key_enabled': True, 'world_readable': False} + response = v2.files.update(fid, body) - _api_v2.response_handler(new_response) - - # decode bytes for python 3.3: https://bugs.python.org/issue10976 - str_content = new_response.content.decode('utf-8') - - new_response_data = json.loads(str_content) - - plot_url += '?share_key=' + new_response_data['share_key'] - - # sometimes a share key is added, but access is still denied - # check for access, and retry a couple of times if this is the case - # https://github.com/plotly/streambed/issues/4089 - embed_url = plot_url.split('?')[0] + '.embed' + plot_url.split('?')[1] - access_res = requests.get(embed_url) - if access_res.status_code == 404: - attempt += 1 - if attempt == 5: - return plot_url - plot_url = add_share_key_to_url(plot_url.split('?')[0], attempt) - - return plot_url + return plot_url + '?share_key=' + response.json()['share_key'] def _send_to_plotly(figure, **plot_options): fig = tools._replace_newline(figure) # does not mutate figure - data = json.dumps(fig['data'] if 'data' in fig else [], - cls=utils.PlotlyJSONEncoder) - credentials = get_credentials() - validate_credentials(credentials) - username = credentials['username'] - api_key = credentials['api_key'] - kwargs = json.dumps(dict(filename=plot_options['filename'], - fileopt=plot_options['fileopt'], - world_readable=plot_options['world_readable'], - sharing=plot_options['sharing'], - layout=fig['layout'] if 'layout' in fig else {}), - cls=utils.PlotlyJSONEncoder) - - # TODO: It'd be cool to expose the platform for RaspPi and others - payload = dict(platform='python', - version=version.__version__, - args=data, - un=username, - key=api_key, - origin='plot', - kwargs=kwargs) - - url = get_config()['plotly_domain'] + "/clientresp" - - r = requests.post(url, data=payload, - verify=get_config()['plotly_ssl_verification']) - r.raise_for_status() - r = json.loads(r.text) - - if 'error' in r and r['error'] != '': - raise exceptions.PlotlyError(r['error']) - - # Check if the url needs a secret key - if (plot_options['sharing'] == 'secret' and - 'share_key=' not in r['url']): + data = fig.get('data', []) + response = v1.clientresp(data, **plot_options) - # add_share_key_to_url updates the url to include the share_key - r['url'] = add_share_key_to_url(r['url']) + parsed_content = response.json() - if 'error' in r and r['error'] != '': - print(r['error']) - if 'warning' in r and r['warning'] != '': - warnings.warn(r['warning']) - if 'message' in r and r['message'] != '': - print(r['message']) + # Check if the url needs a secret key + if plot_options['sharing'] == 'secret': + url = parsed_content['url'] + if 'share_key=' not in url: + # add_share_key_to_url updates the url to include the share_key + parsed_content['url'] = add_share_key_to_url(url) - return r + return parsed_content def get_grid(grid_url, raw=False): """ Returns the specified grid as a Grid instance or in JSON/dict form. + :param (str) grid_url: The web_url which locates a Plotly grid. :param (bool) raw: if False, will output a Grid instance of the JSON grid being retrieved. If True, raw JSON will be returned. """ - credentials = get_credentials() - validate_credentials(credentials) - username, api_key = credentials['username'], credentials['api_key'] - headers = {'plotly-username': username, - 'plotly-apikey': api_key, - 'plotly-version': version.__version__, - 'plotly-platform': 'python'} - upload_url = _api_v2.api_url('grids') - - # extract path in grid url - url_path = six.moves.urllib.parse.urlparse(grid_url)[2][2:] - if url_path[-1] == '/': - url_path = url_path[0: -1] - url_path = url_path.replace('/', ':') - - meta_get_url = upload_url + '/' + url_path - get_url = meta_get_url + '/content' - - r = requests.get(get_url, headers=headers) - json_res = json.loads(r.text) - - # make request to grab the grid id (fid) - r_meta = requests.get(meta_get_url, headers=headers) - r_meta.raise_for_status() - - json_res_meta = json.loads(r_meta.text) - retrieved_grid_id = json_res_meta['fid'] - - if raw is False: - return Grid(json_res, retrieved_grid_id) - else: - return json_res + fid = parse_grid_id_args(None, grid_url) + response = v2.grids.content(fid) + parsed_content = response.json() + + if raw: + return parsed_content + return Grid(parsed_content, fid) def create_animations(figure, filename=None, sharing='public', auto_open=True): @@ -1661,14 +1508,7 @@ def create_animations(figure, filename=None, sharing='public', auto_open=True): py.create_animations(figure, 'growing_circles') ``` """ - credentials = get_credentials() - validate_credentials(credentials) - username, api_key = credentials['username'], credentials['api_key'] - auth = HTTPBasicAuth(str(username), str(api_key)) - headers = {'Plotly-Client-Platform': 'python', - 'content-type': 'application/json'} - - json = { + body = { 'figure': figure, 'world_readable': True } @@ -1682,48 +1522,30 @@ def create_animations(figure, filename=None, sharing='public', auto_open=True): "automatic folder creation. This means a filename of the form " "'name1/name2' will just create the plot with that name only." ) - json['filename'] = filename + body['filename'] = filename # set sharing if sharing == 'public': - json['world_readable'] = True + body['world_readable'] = True elif sharing == 'private': - json['world_readable'] = False + body['world_readable'] = False elif sharing == 'secret': - json['world_readable'] = False - json['share_key_enabled'] = True + body['world_readable'] = False + body['share_key_enabled'] = True else: raise exceptions.PlotlyError( "Whoops, sharing can only be set to either 'public', 'private', " "or 'secret'." ) - api_url = _api_v2.api_url('plots') - r = requests.post(api_url, auth=auth, headers=headers, json=json) - - try: - parsed_response = r.json() - except: - parsed_response = r.content - - # raise error message - if not r.ok: - message = '' - if isinstance(parsed_response, dict): - errors = parsed_response.get('errors') - if errors and errors[-1].get('message'): - message = errors[-1]['message'] - if message: - raise exceptions.PlotlyError(message) - else: - # shucks, we're stuck with a generic error... - r.raise_for_status() + response = v2.plots.create(body) + parsed_content = response.json() if sharing == 'secret': - web_url = (parsed_response['file']['web_url'][:-1] + - '?share_key=' + parsed_response['file']['share_key']) + web_url = (parsed_content['file']['web_url'][:-1] + + '?share_key=' + parsed_content['file']['share_key']) else: - web_url = parsed_response['file']['web_url'] + web_url = parsed_content['file']['web_url'] if auto_open: _open_url(web_url) @@ -1738,7 +1560,6 @@ def icreate_animations(figure, filename=None, sharing='public', auto_open=False) This function is based off `plotly.plotly.iplot`. See `plotly.plotly. create_animations` Doc String for param descriptions. """ - # Still needs doing: create a wrapper for iplot and icreate_animations url = create_animations(figure, filename, sharing, auto_open) if isinstance(figure, dict): diff --git a/plotly/session.py b/plotly/session.py index e93d9a85996..2e72d45bff3 100644 --- a/plotly/session.py +++ b/plotly/session.py @@ -22,6 +22,8 @@ CREDENTIALS_KEYS = { 'username': six.string_types, 'api_key': six.string_types, + 'proxy_username': six.string_types, + 'proxy_password': six.string_types, 'stream_ids': list } diff --git a/plotly/tests/test_core/test_api/__init__.py b/plotly/tests/test_core/test_api/__init__.py new file mode 100644 index 00000000000..f8a93ee0238 --- /dev/null +++ b/plotly/tests/test_core/test_api/__init__.py @@ -0,0 +1,59 @@ +from __future__ import absolute_import + +from mock import patch +from requests import Response + +from plotly.session import sign_in +from plotly.tests.utils import PlotlyTestCase + + +class PlotlyApiTestCase(PlotlyTestCase): + + def mock(self, path_string): + patcher = patch(path_string) + new_mock = patcher.start() + self.addCleanup(patcher.stop) + return new_mock + + def setUp(self): + + super(PlotlyApiTestCase, self).setUp() + + self.username = 'foo' + self.api_key = 'bar' + + self.proxy_username = 'cnet' + self.proxy_password = 'hoopla' + self.stream_ids = ['heyThere'] + + self.plotly_api_domain = 'https://api.do.not.exist' + self.plotly_domain = 'https://who.am.i' + self.plotly_proxy_authorization = False + self.plotly_streaming_domain = 'stream.does.not.exist' + self.plotly_ssl_verification = True + + sign_in( + username=self.username, + api_key=self.api_key, + proxy_username=self.proxy_username, + proxy_password=self.proxy_password, + stream_ids = self.stream_ids, + plotly_domain=self.plotly_domain, + plotly_api_domain=self.plotly_api_domain, + plotly_streaming_domain=self.plotly_streaming_domain, + plotly_proxy_authorization=self.plotly_proxy_authorization, + plotly_ssl_verification=self.plotly_ssl_verification + ) + + def to_bytes(self, string): + try: + return string.encode('utf-8') + except AttributeError: + return string + + def get_response(self, content=b'', status_code=200): + response = Response() + response.status_code = status_code + response._content = content + response.encoding = 'utf-8' + return response diff --git a/plotly/tests/test_core/test_api/test_v1/__init__.py b/plotly/tests/test_core/test_api/test_v1/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/plotly/tests/test_core/test_api/test_v1/test_clientresp.py b/plotly/tests/test_core/test_api/test_v1/test_clientresp.py new file mode 100644 index 00000000000..784ca087642 --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v1/test_clientresp.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import + +from plotly import version +from plotly.api.v1 import clientresp +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class Duck(object): + def to_plotly_json(self): + return 'what else floats?' + + +class ClientrespTest(PlotlyApiTestCase): + + def setUp(self): + super(ClientrespTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v1.utils.requests.request') + self.request_mock.return_value = self.get_response(b'{}', 200) + + # Mock the validation function since we can test that elsewhere. + self.mock('plotly.api.v1.utils.validate_response') + + def test_data_only(self): + data = [{'y': [3, 5], 'name': Duck()}] + clientresp(data) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual(url, '{}/clientresp'.format(self.plotly_domain)) + expected_data = ({ + 'origin': 'plot', + 'args': '[{"name": "what else floats?", "y": [3, 5]}]', + 'platform': 'python', 'version': version.__version__, 'key': 'bar', + 'kwargs': '{}', 'un': 'foo' + }) + self.assertEqual(kwargs['data'], expected_data) + self.assertTrue(kwargs['verify']) + self.assertEqual(kwargs['headers'], {}) + + def test_data_and_kwargs(self): + data = [{'y': [3, 5], 'name': Duck()}] + clientresp_kwargs = {'layout': {'title': 'mah plot'}, 'filename': 'ok'} + clientresp(data, **clientresp_kwargs) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual(url, '{}/clientresp'.format(self.plotly_domain)) + expected_data = ({ + 'origin': 'plot', + 'args': '[{"name": "what else floats?", "y": [3, 5]}]', + 'platform': 'python', 'version': version.__version__, 'key': 'bar', + 'kwargs': '{"filename": "ok", "layout": {"title": "mah plot"}}', + 'un': 'foo' + }) + self.assertEqual(kwargs['data'], expected_data) + self.assertTrue(kwargs['verify']) + self.assertEqual(kwargs['headers'], {}) diff --git a/plotly/tests/test_core/test_api/test_v1/test_utils.py b/plotly/tests/test_core/test_api/test_v1/test_utils.py new file mode 100644 index 00000000000..dee352db785 --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v1/test_utils.py @@ -0,0 +1,175 @@ +from __future__ import absolute_import + +from unittest import TestCase + +from mock import MagicMock, patch +from requests import Response +from requests.compat import json as _json +from requests.exceptions import ConnectionError + +from plotly.api.utils import to_native_utf8_string +from plotly.api.v1 import utils +from plotly.exceptions import PlotlyError, PlotlyRequestError +from plotly.session import sign_in +from plotly.tests.test_core.test_api import PlotlyApiTestCase +from plotly.tests.utils import PlotlyTestCase + + +class ValidateResponseTest(PlotlyApiTestCase): + + def test_validate_ok(self): + try: + utils.validate_response(self.get_response(content=b'{}')) + except PlotlyRequestError: + self.fail('Expected this to pass!') + + def test_validate_not_ok(self): + bad_status_codes = (400, 404, 500) + for bad_status_code in bad_status_codes: + response = self.get_response(content=b'{}', + status_code=bad_status_code) + self.assertRaises(PlotlyRequestError, utils.validate_response, + response) + + def test_validate_no_content(self): + + # We shouldn't flake if the response has no content. + + response = self.get_response(content=b'', status_code=200) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, 'No Content') + self.assertEqual(e.status_code, 200) + self.assertEqual(e.content, b'') + else: + self.fail('Expected this to raise!') + + def test_validate_non_json_content(self): + response = self.get_response(content=b'foobar', status_code=200) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, 'foobar') + self.assertEqual(e.status_code, 200) + self.assertEqual(e.content, b'foobar') + else: + self.fail('Expected this to raise!') + + def test_validate_json_content_array(self): + content = self.to_bytes(_json.dumps([1, 2, 3])) + response = self.get_response(content=content, status_code=200) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, to_native_utf8_string(content)) + self.assertEqual(e.status_code, 200) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + def test_validate_json_content_dict_no_error(self): + content = self.to_bytes(_json.dumps({'foo': 'bar'})) + response = self.get_response(content=content, status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, to_native_utf8_string(content)) + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + def test_validate_json_content_dict_error_empty(self): + content = self.to_bytes(_json.dumps({'error': ''})) + response = self.get_response(content=content, status_code=200) + try: + utils.validate_response(response) + except PlotlyRequestError: + self.fail('Expected this not to raise!') + + def test_validate_json_content_dict_one_error_ok(self): + content = self.to_bytes(_json.dumps({'error': 'not ok!'})) + response = self.get_response(content=content, status_code=200) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, 'not ok!') + self.assertEqual(e.status_code, 200) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + +class GetHeadersTest(PlotlyTestCase): + + def setUp(self): + super(GetHeadersTest, self).setUp() + self.domain = 'https://foo.bar' + self.username = 'hodor' + self.api_key = 'secret' + sign_in(self.username, self.api_key, proxy_username='kleen-kanteen', + proxy_password='hydrated', plotly_proxy_authorization=False) + + def test_normal_auth(self): + headers = utils.get_headers() + expected_headers = {} + self.assertEqual(headers, expected_headers) + + def test_proxy_auth(self): + sign_in(self.username, self.api_key, plotly_proxy_authorization=True) + headers = utils.get_headers() + expected_headers = { + 'authorization': 'Basic a2xlZW4ta2FudGVlbjpoeWRyYXRlZA==' + } + self.assertEqual(headers, expected_headers) + + +class RequestTest(PlotlyTestCase): + + def setUp(self): + super(RequestTest, self).setUp() + self.domain = 'https://foo.bar' + self.username = 'hodor' + self.api_key = 'secret' + sign_in(self.username, self.api_key, proxy_username='kleen-kanteen', + proxy_password='hydrated', plotly_proxy_authorization=False) + + # Mock the actual api call, we don't want to do network tests here. + patcher = patch('plotly.api.v1.utils.requests.request') + self.request_mock = patcher.start() + self.addCleanup(patcher.stop) + self.request_mock.return_value = MagicMock(Response) + + # Mock the validation function since we test that elsewhere. + patcher = patch('plotly.api.v1.utils.validate_response') + self.validate_response_mock = patcher.start() + self.addCleanup(patcher.stop) + + self.method = 'get' + self.url = 'https://foo.bar.does.not.exist.anywhere' + + def test_request_with_json(self): + + # You can pass along non-native objects in the `json` kwarg for a + # requests.request, however, V1 packs up json objects a little + # differently, so we don't allow such requests. + + self.assertRaises(PlotlyError, utils.request, self.method, + self.url, json={}) + + def test_request_with_ConnectionError(self): + + # requests can flake out and not return a response object, we want to + # make sure we remain consistent with our errors. + + self.request_mock.side_effect = ConnectionError() + self.assertRaises(PlotlyRequestError, utils.request, self.method, + self.url) + + def test_request_validate_response(self): + + # Finally, we check details elsewhere, but make sure we do validate. + + utils.request(self.method, self.url) + self.validate_response_mock.assert_called_once() diff --git a/plotly/tests/test_core/test_api/test_v2/__init__.py b/plotly/tests/test_core/test_api/test_v2/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/plotly/tests/test_core/test_api/test_v2/test_files.py b/plotly/tests/test_core/test_api/test_v2/test_files.py new file mode 100644 index 00000000000..32e4ec99347 --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v2/test_files.py @@ -0,0 +1,104 @@ +from __future__ import absolute_import + +from plotly.api.v2 import files +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class FilesTest(PlotlyApiTestCase): + + def setUp(self): + super(FilesTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v2.utils.requests.request') + self.request_mock.return_value = self.get_response() + + # Mock the validation function since we can test that elsewhere. + self.mock('plotly.api.v2.utils.validate_response') + + def test_retrieve(self): + files.retrieve('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/files/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {}) + + def test_retrieve_share_key(self): + files.retrieve('hodor:88', share_key='foobar') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/files/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {'share_key': 'foobar'}) + + def test_update(self): + new_filename = '..zzZ ..zzZ' + files.update('hodor:88', body={'filename': new_filename}) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'put') + self.assertEqual( + url, '{}/v2/files/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['data'], + '{{"filename": "{}"}}'.format(new_filename)) + + def test_trash(self): + files.trash('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/files/hodor:88/trash'.format(self.plotly_api_domain) + ) + + def test_restore(self): + files.restore('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/files/hodor:88/restore'.format(self.plotly_api_domain) + ) + + def test_permanent_delete(self): + files.permanent_delete('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'delete') + self.assertEqual( + url, + '{}/v2/files/hodor:88/permanent_delete' + .format(self.plotly_api_domain) + ) + + def test_lookup(self): + + # requests does urlencode, so don't worry about the `' '` character! + + path = '/mah plot' + parent = 43 + user = 'someone' + exists = True + files.lookup(path=path, parent=parent, user=user, exists=exists) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + expected_params = {'path': path, 'parent': parent, 'exists': 'true', + 'user': user} + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/files/lookup'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], expected_params) diff --git a/plotly/tests/test_core/test_api/test_v2/test_folders.py b/plotly/tests/test_core/test_api/test_v2/test_folders.py new file mode 100644 index 00000000000..0365ad79879 --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v2/test_folders.py @@ -0,0 +1,114 @@ +from __future__ import absolute_import + +from plotly.api.v2 import folders +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class FoldersTest(PlotlyApiTestCase): + + def setUp(self): + super(FoldersTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v2.utils.requests.request') + self.request_mock.return_value = self.get_response() + + # Mock the validation function since we can test that elsewhere. + self.mock('plotly.api.v2.utils.validate_response') + + def test_create(self): + path = '/foo/man/bar/' + folders.create({'path': path}) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual(url, '{}/v2/folders'.format(self.plotly_api_domain)) + self.assertEqual(kwargs['data'], '{{"path": "{}"}}'.format(path)) + + def test_retrieve(self): + folders.retrieve('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/folders/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {}) + + def test_retrieve_share_key(self): + folders.retrieve('hodor:88', share_key='foobar') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/folders/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {'share_key': 'foobar'}) + + def test_update(self): + new_filename = '..zzZ ..zzZ' + folders.update('hodor:88', body={'filename': new_filename}) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'put') + self.assertEqual( + url, '{}/v2/folders/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['data'], + '{{"filename": "{}"}}'.format(new_filename)) + + def test_trash(self): + folders.trash('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/folders/hodor:88/trash'.format(self.plotly_api_domain) + ) + + def test_restore(self): + folders.restore('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/folders/hodor:88/restore'.format(self.plotly_api_domain) + ) + + def test_permanent_delete(self): + folders.permanent_delete('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'delete') + self.assertEqual( + url, + '{}/v2/folders/hodor:88/permanent_delete' + .format(self.plotly_api_domain) + ) + + def test_lookup(self): + + # requests does urlencode, so don't worry about the `' '` character! + + path = '/mah folder' + parent = 43 + user = 'someone' + exists = True + folders.lookup(path=path, parent=parent, user=user, exists=exists) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + expected_params = {'path': path, 'parent': parent, 'exists': 'true', + 'user': user} + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/folders/lookup'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], expected_params) diff --git a/plotly/tests/test_core/test_api/test_v2/test_grids.py b/plotly/tests/test_core/test_api/test_v2/test_grids.py new file mode 100644 index 00000000000..ff6fb3ec1b3 --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v2/test_grids.py @@ -0,0 +1,185 @@ +from __future__ import absolute_import + +from requests.compat import json as _json + +from plotly.api.v2 import grids +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class GridsTest(PlotlyApiTestCase): + + def setUp(self): + super(GridsTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v2.utils.requests.request') + self.request_mock.return_value = self.get_response() + + # Mock the validation function since we can test that elsewhere. + self.mock('plotly.api.v2.utils.validate_response') + + def test_create(self): + filename = 'a grid' + grids.create({'filename': filename}) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual(url, '{}/v2/grids'.format(self.plotly_api_domain)) + self.assertEqual( + kwargs['data'], '{{"filename": "{}"}}'.format(filename) + ) + + def test_retrieve(self): + grids.retrieve('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/grids/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {}) + + def test_retrieve_share_key(self): + grids.retrieve('hodor:88', share_key='foobar') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/grids/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {'share_key': 'foobar'}) + + def test_update(self): + new_filename = '..zzZ ..zzZ' + grids.update('hodor:88', body={'filename': new_filename}) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'put') + self.assertEqual( + url, '{}/v2/grids/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['data'], + '{{"filename": "{}"}}'.format(new_filename)) + + def test_trash(self): + grids.trash('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/grids/hodor:88/trash'.format(self.plotly_api_domain) + ) + + def test_restore(self): + grids.restore('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/grids/hodor:88/restore'.format(self.plotly_api_domain) + ) + + def test_permanent_delete(self): + grids.permanent_delete('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'delete') + self.assertEqual( + url, + '{}/v2/grids/hodor:88/permanent_delete' + .format(self.plotly_api_domain) + ) + + def test_lookup(self): + + # requests does urlencode, so don't worry about the `' '` character! + + path = '/mah grid' + parent = 43 + user = 'someone' + exists = True + grids.lookup(path=path, parent=parent, user=user, exists=exists) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + expected_params = {'path': path, 'parent': parent, 'exists': 'true', + 'user': user} + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/grids/lookup'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], expected_params) + + def test_col_create(self): + cols = [ + {'name': 'foo', 'data': [1, 2, 3]}, + {'name': 'bar', 'data': ['b', 'a', 'r']}, + ] + body = {'cols': _json.dumps(cols, sort_keys=True)} + grids.col_create('hodor:88', body) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/grids/hodor:88/col'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['data'], _json.dumps(body, sort_keys=True)) + + def test_col_retrieve(self): + grids.col_retrieve('hodor:88', 'aaaaaa,bbbbbb') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/grids/hodor:88/col'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {'uid': 'aaaaaa,bbbbbb'}) + + def test_col_update(self): + cols = [ + {'name': 'foo', 'data': [1, 2, 3]}, + {'name': 'bar', 'data': ['b', 'a', 'r']}, + ] + body = {'cols': _json.dumps(cols, sort_keys=True)} + grids.col_update('hodor:88', 'aaaaaa,bbbbbb', body) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'put') + self.assertEqual( + url, '{}/v2/grids/hodor:88/col'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {'uid': 'aaaaaa,bbbbbb'}) + self.assertEqual(kwargs['data'], _json.dumps(body, sort_keys=True)) + + def test_col_delete(self): + grids.col_delete('hodor:88', 'aaaaaa,bbbbbb') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'delete') + self.assertEqual( + url, '{}/v2/grids/hodor:88/col'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {'uid': 'aaaaaa,bbbbbb'}) + + def test_row(self): + body = {'rows': [[1, 'A'], [2, 'B']]} + grids.row('hodor:88', body) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/grids/hodor:88/row'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['data'], _json.dumps(body, sort_keys=True)) diff --git a/plotly/tests/test_core/test_api/test_v2/test_images.py b/plotly/tests/test_core/test_api/test_v2/test_images.py new file mode 100644 index 00000000000..480cf0f05bf --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v2/test_images.py @@ -0,0 +1,41 @@ +from __future__ import absolute_import + +from requests.compat import json as _json + +from plotly.api.v2 import images +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class ImagesTest(PlotlyApiTestCase): + + def setUp(self): + super(ImagesTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v2.utils.requests.request') + self.request_mock.return_value = self.get_response() + + # Mock the validation function since we can test that elsewhere. + self.mock('plotly.api.v2.utils.validate_response') + + def test_create(self): + + body = { + "figure": { + "data": [{"y": [10, 10, 2, 20]}], + "layout": {"width": 700} + }, + "width": 1000, + "height": 500, + "format": "png", + "scale": 4, + "encoded": False + } + + images.create(body) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual(url, '{}/v2/images'.format(self.plotly_api_domain)) + self.assertEqual(kwargs['data'], _json.dumps(body, sort_keys=True)) diff --git a/plotly/tests/test_core/test_api/test_v2/test_plot_schema.py b/plotly/tests/test_core/test_api/test_v2/test_plot_schema.py new file mode 100644 index 00000000000..b52f1b3a000 --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v2/test_plot_schema.py @@ -0,0 +1,30 @@ +from __future__ import absolute_import + +from plotly.api.v2 import plot_schema +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class PlotSchemaTest(PlotlyApiTestCase): + + def setUp(self): + super(PlotSchemaTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v2.utils.requests.request') + self.request_mock.return_value = self.get_response() + + # Mock the validation function since we can test that elsewhere. + self.mock('plotly.api.v2.utils.validate_response') + + def test_retrieve(self): + + plot_schema.retrieve('some-hash', timeout=400) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/plot-schema'.format(self.plotly_api_domain) + ) + self.assertTrue(kwargs['timeout']) + self.assertEqual(kwargs['params'], {'sha1': 'some-hash'}) diff --git a/plotly/tests/test_core/test_api/test_v2/test_plots.py b/plotly/tests/test_core/test_api/test_v2/test_plots.py new file mode 100644 index 00000000000..31d50cb7aaf --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v2/test_plots.py @@ -0,0 +1,116 @@ +from __future__ import absolute_import + +from plotly.api.v2 import plots +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class PlotsTest(PlotlyApiTestCase): + + def setUp(self): + super(PlotsTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v2.utils.requests.request') + self.request_mock.return_value = self.get_response() + + # Mock the validation function since we can test that elsewhere. + self.mock('plotly.api.v2.utils.validate_response') + + def test_create(self): + filename = 'a plot' + plots.create({'filename': filename}) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual(url, '{}/v2/plots'.format(self.plotly_api_domain)) + self.assertEqual( + kwargs['data'], '{{"filename": "{}"}}'.format(filename) + ) + + def test_retrieve(self): + plots.retrieve('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/plots/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {}) + + def test_retrieve_share_key(self): + plots.retrieve('hodor:88', share_key='foobar') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/plots/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], {'share_key': 'foobar'}) + + def test_update(self): + new_filename = '..zzZ ..zzZ' + plots.update('hodor:88', body={'filename': new_filename}) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'put') + self.assertEqual( + url, '{}/v2/plots/hodor:88'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['data'], + '{{"filename": "{}"}}'.format(new_filename)) + + def test_trash(self): + plots.trash('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/plots/hodor:88/trash'.format(self.plotly_api_domain) + ) + + def test_restore(self): + plots.restore('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'post') + self.assertEqual( + url, '{}/v2/plots/hodor:88/restore'.format(self.plotly_api_domain) + ) + + def test_permanent_delete(self): + plots.permanent_delete('hodor:88') + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'delete') + self.assertEqual( + url, + '{}/v2/plots/hodor:88/permanent_delete' + .format(self.plotly_api_domain) + ) + + def test_lookup(self): + + # requests does urlencode, so don't worry about the `' '` character! + + path = '/mah plot' + parent = 43 + user = 'someone' + exists = True + plots.lookup(path=path, parent=parent, user=user, exists=exists) + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + expected_params = {'path': path, 'parent': parent, 'exists': 'true', + 'user': user} + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/plots/lookup'.format(self.plotly_api_domain) + ) + self.assertEqual(kwargs['params'], expected_params) diff --git a/plotly/tests/test_core/test_api/test_v2/test_users.py b/plotly/tests/test_core/test_api/test_v2/test_users.py new file mode 100644 index 00000000000..59cf8731d56 --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v2/test_users.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import + +from plotly.api.v2 import users +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class UsersTest(PlotlyApiTestCase): + + def setUp(self): + super(UsersTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v2.utils.requests.request') + self.request_mock.return_value = self.get_response() + + # Mock the validation function since we can test that elsewhere. + self.mock('plotly.api.v2.utils.validate_response') + + def test_current(self): + users.current() + self.request_mock.assert_called_once() + args, kwargs = self.request_mock.call_args + method, url = args + self.assertEqual(method, 'get') + self.assertEqual( + url, '{}/v2/users/current'.format(self.plotly_api_domain) + ) + self.assertNotIn('params', kwargs) diff --git a/plotly/tests/test_core/test_api/test_v2/test_utils.py b/plotly/tests/test_core/test_api/test_v2/test_utils.py new file mode 100644 index 00000000000..c370ef418ed --- /dev/null +++ b/plotly/tests/test_core/test_api/test_v2/test_utils.py @@ -0,0 +1,252 @@ +from __future__ import absolute_import + +from requests.compat import json as _json +from requests.exceptions import ConnectionError + +from plotly import version +from plotly.api.utils import to_native_utf8_string +from plotly.api.v2 import utils +from plotly.exceptions import PlotlyRequestError +from plotly.session import sign_in +from plotly.tests.test_core.test_api import PlotlyApiTestCase + + +class MakeParamsTest(PlotlyApiTestCase): + + def test_make_params(self): + params = utils.make_params(foo='FOO', bar=None) + self.assertEqual(params, {'foo': 'FOO'}) + + def test_make_params_empty(self): + params = utils.make_params(foo=None, bar=None) + self.assertEqual(params, {}) + + +class BuildUrlTest(PlotlyApiTestCase): + + def test_build_url(self): + url = utils.build_url('cats') + self.assertEqual(url, '{}/v2/cats'.format(self.plotly_api_domain)) + + def test_build_url_id(self): + url = utils.build_url('cats', id='MsKitty') + self.assertEqual( + url, '{}/v2/cats/MsKitty'.format(self.plotly_api_domain) + ) + + def test_build_url_route(self): + url = utils.build_url('cats', route='about') + self.assertEqual( + url, '{}/v2/cats/about'.format(self.plotly_api_domain) + ) + + def test_build_url_id_route(self): + url = utils.build_url('cats', id='MsKitty', route='de-claw') + self.assertEqual( + url, '{}/v2/cats/MsKitty/de-claw'.format(self.plotly_api_domain) + ) + + +class ValidateResponseTest(PlotlyApiTestCase): + + def test_validate_ok(self): + try: + utils.validate_response(self.get_response()) + except PlotlyRequestError: + self.fail('Expected this to pass!') + + def test_validate_not_ok(self): + bad_status_codes = (400, 404, 500) + for bad_status_code in bad_status_codes: + response = self.get_response(status_code=bad_status_code) + self.assertRaises(PlotlyRequestError, utils.validate_response, + response) + + def test_validate_no_content(self): + + # We shouldn't flake if the response has no content. + + response = self.get_response(content=b'', status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, u'No Content') + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content.decode('utf-8'), u'') + else: + self.fail('Expected this to raise!') + + def test_validate_non_json_content(self): + response = self.get_response(content=b'foobar', status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, 'foobar') + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content, b'foobar') + else: + self.fail('Expected this to raise!') + + def test_validate_json_content_array(self): + content = self.to_bytes(_json.dumps([1, 2, 3])) + response = self.get_response(content=content, status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, to_native_utf8_string(content)) + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + def test_validate_json_content_dict_no_errors(self): + content = self.to_bytes(_json.dumps({'foo': 'bar'})) + response = self.get_response(content=content, status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, to_native_utf8_string(content)) + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + def test_validate_json_content_dict_one_error_bad(self): + content = self.to_bytes(_json.dumps({'errors': [{}]})) + response = self.get_response(content=content, status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, to_native_utf8_string(content)) + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + content = self.to_bytes(_json.dumps({'errors': [{'message': ''}]})) + response = self.get_response(content=content, status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, to_native_utf8_string(content)) + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + def test_validate_json_content_dict_one_error_ok(self): + content = self.to_bytes(_json.dumps( + {'errors': [{'message': 'not ok!'}]})) + response = self.get_response(content=content, status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, 'not ok!') + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + def test_validate_json_content_dict_multiple_errors(self): + content = self.to_bytes(_json.dumps({'errors': [ + {'message': 'not ok!'}, {'message': 'bad job...'} + ]})) + response = self.get_response(content=content, status_code=400) + try: + utils.validate_response(response) + except PlotlyRequestError as e: + self.assertEqual(e.message, 'not ok!\nbad job...') + self.assertEqual(e.status_code, 400) + self.assertEqual(e.content, content) + else: + self.fail('Expected this to raise!') + + +class GetHeadersTest(PlotlyApiTestCase): + + def test_normal_auth(self): + headers = utils.get_headers() + expected_headers = { + 'plotly-client-platform': 'python {}'.format(version.__version__), + 'authorization': 'Basic Zm9vOmJhcg==', + 'content-type': 'application/json' + } + self.assertEqual(headers, expected_headers) + + def test_proxy_auth(self): + sign_in(self.username, self.api_key, plotly_proxy_authorization=True) + headers = utils.get_headers() + expected_headers = { + 'plotly-client-platform': 'python {}'.format(version.__version__), + 'authorization': 'Basic Y25ldDpob29wbGE=', + 'plotly-authorization': 'Basic Zm9vOmJhcg==', + 'content-type': 'application/json' + } + self.assertEqual(headers, expected_headers) + + +class RequestTest(PlotlyApiTestCase): + + def setUp(self): + super(RequestTest, self).setUp() + + # Mock the actual api call, we don't want to do network tests here. + self.request_mock = self.mock('plotly.api.v2.utils.requests.request') + self.request_mock.return_value = self.get_response() + + # Mock the validation function since we can test that elsewhere. + self.validate_response_mock = self.mock( + 'plotly.api.v2.utils.validate_response') + + self.method = 'get' + self.url = 'https://foo.bar.does.not.exist.anywhere' + + def test_request_with_params(self): + + # urlencode transforms `True` --> `'True'`, which isn't super helpful, + # Our backend accepts the JS `true`, so we want `True` --> `'true'`. + + params = {'foo': True, 'bar': 'True', 'baz': False, 'zap': 0} + utils.request(self.method, self.url, params=params) + args, kwargs = self.request_mock.call_args + method, url = args + expected_params = {'foo': 'true', 'bar': 'True', 'baz': 'false', + 'zap': 0} + self.assertEqual(method, self.method) + self.assertEqual(url, self.url) + self.assertEqual(kwargs['params'], expected_params) + + def test_request_with_non_native_objects(self): + + # We always send along json, but it may contain non-native objects like + # a pandas array or a Column reference. Make sure that's handled in one + # central place. + + class Duck(object): + def to_plotly_json(self): + return 'what else floats?' + + utils.request(self.method, self.url, json={'foo': [Duck(), Duck()]}) + args, kwargs = self.request_mock.call_args + method, url = args + expected_data = '{"foo": ["what else floats?", "what else floats?"]}' + self.assertEqual(method, self.method) + self.assertEqual(url, self.url) + self.assertEqual(kwargs['data'], expected_data) + self.assertNotIn('json', kwargs) + + def test_request_with_ConnectionError(self): + + # requests can flake out and not return a response object, we want to + # make sure we remain consistent with our errors. + + self.request_mock.side_effect = ConnectionError() + self.assertRaises(PlotlyRequestError, utils.request, self.method, + self.url) + + def test_request_validate_response(self): + + # Finally, we check details elsewhere, but make sure we do validate. + + utils.request(self.method, self.url) + self.validate_response_mock.assert_called_once() diff --git a/plotly/tests/test_core/test_file/test_file.py b/plotly/tests/test_core/test_file/test_file.py index cd4ea5b9a47..c8b3bb8680a 100644 --- a/plotly/tests/test_core/test_file/test_file.py +++ b/plotly/tests/test_core/test_file/test_file.py @@ -49,7 +49,7 @@ def test_duplicate_folders(self): py.file_ops.mkdirs(first_folder) try: py.file_ops.mkdirs(first_folder) - except requests.exceptions.RequestException as e: - self.assertTrue(400 <= e.response.status_code < 500) + except PlotlyRequestError as e: + self.assertTrue(400 <= e.status_code < 500) else: self.fail('Expected this to fail!') diff --git a/plotly/tests/test_core/test_get_requests/test_get_requests.py b/plotly/tests/test_core/test_get_requests/test_get_requests.py index 4c4fd939e47..1719d86b38d 100644 --- a/plotly/tests/test_core/test_get_requests/test_get_requests.py +++ b/plotly/tests/test_core/test_get_requests/test_get_requests.py @@ -6,11 +6,11 @@ """ import copy -import json -import requests +import requests import six from nose.plugins.attrib import attr +from requests.compat import json as _json default_headers = {'plotly-username': '', @@ -37,9 +37,9 @@ def test_user_does_not_exist(): resource = "/apigetfile/{0}/{1}/".format(file_owner, file_id) response = requests.get(server + resource, headers=hd) if six.PY3: - content = json.loads(response.content.decode('unicode_escape')) + content = _json.loads(response.content.decode('unicode_escape')) else: - content = json.loads(response.content) + content = _json.loads(response.content) print(response.status_code) print(content) assert response.status_code == 404 @@ -60,9 +60,9 @@ def test_file_does_not_exist(): resource = "/apigetfile/{0}/{1}/".format(file_owner, file_id) response = requests.get(server + resource, headers=hd) if six.PY3: - content = json.loads(response.content.decode('unicode_escape')) + content = _json.loads(response.content.decode('unicode_escape')) else: - content = json.loads(response.content) + content = _json.loads(response.content) print(response.status_code) print(content) assert response.status_code == 404 @@ -100,9 +100,9 @@ def test_private_permission_defined(): resource = "/apigetfile/{0}/{1}/".format(file_owner, file_id) response = requests.get(server + resource, headers=hd) if six.PY3: - content = json.loads(response.content.decode('unicode_escape')) + content = _json.loads(response.content.decode('unicode_escape')) else: - content = json.loads(response.content) + content = _json.loads(response.content) print(response.status_code) print(content) assert response.status_code == 403 @@ -122,9 +122,9 @@ def test_missing_headers(): del hd[header] response = requests.get(server + resource, headers=hd) if six.PY3: - content = json.loads(response.content.decode('unicode_escape')) + content = _json.loads(response.content.decode('unicode_escape')) else: - content = json.loads(response.content) + content = _json.loads(response.content) print(response.status_code) print(content) assert response.status_code == 422 @@ -142,13 +142,13 @@ def test_valid_request(): resource = "/apigetfile/{0}/{1}/".format(file_owner, file_id) response = requests.get(server + resource, headers=hd) if six.PY3: - content = json.loads(response.content.decode('unicode_escape')) + content = _json.loads(response.content.decode('unicode_escape')) else: - content = json.loads(response.content) + content = _json.loads(response.content) print(response.status_code) print(content) assert response.status_code == 200 - # content = json.loads(res.content) + # content = _json.loads(res.content) # response_payload = content['payload'] # figure = response_payload['figure'] # if figure['data'][0]['x'] != [u'1', u'2', u'3']: diff --git a/plotly/tests/test_core/test_graph_reference/test_graph_reference.py b/plotly/tests/test_core/test_graph_reference/test_graph_reference.py index f3f97c7c5a8..005d33a7c05 100644 --- a/plotly/tests/test_core/test_graph_reference/test_graph_reference.py +++ b/plotly/tests/test_core/test_graph_reference/test_graph_reference.py @@ -4,36 +4,31 @@ """ from __future__ import absolute_import -import json import os from pkg_resources import resource_string from unittest import TestCase -import requests -import six from nose.plugins.attrib import attr +from requests.compat import json as _json -from plotly import files, graph_reference as gr +from plotly import graph_reference as gr +from plotly.api import v2 from plotly.graph_reference import string_to_class_name, get_role from plotly.tests.utils import PlotlyTestCase +FAKE_API_DOMAIN = 'https://api.am.not.here.ly' + class TestGraphReferenceCaching(PlotlyTestCase): @attr('slow') def test_default_schema_is_up_to_date(self): - api_domain = files.FILE_CONTENT[files.CONFIG_FILE]['plotly_api_domain'] - graph_reference_url = '{}{}?sha1'.format(api_domain, '/v2/plot-schema') - response = requests.get(graph_reference_url) - if six.PY3: - content = str(response.content, encoding='utf-8') - else: - content = response.content - schema = json.loads(content)['schema'] + response = v2.plot_schema.retrieve('') + schema = response.json()['schema'] path = os.path.join('package_data', 'default-schema.json') s = resource_string('plotly', path).decode('utf-8') - default_schema = json.loads(s) + default_schema = _json.loads(s) msg = ( 'The default, hard-coded plot schema we ship with pip is out of ' diff --git a/plotly/tests/test_core/test_grid/test_grid.py b/plotly/tests/test_core/test_grid/test_grid.py index d711fd8fd73..4ccd3690136 100644 --- a/plotly/tests/test_core/test_grid/test_grid.py +++ b/plotly/tests/test_core/test_grid/test_grid.py @@ -9,7 +9,6 @@ import random import string -import requests from nose import with_setup from nose.plugins.attrib import attr @@ -17,10 +16,10 @@ from unittest import skip import plotly.plotly as py -from plotly.exceptions import InputError, PlotlyRequestError +from plotly.exceptions import InputError, PlotlyRequestError, PlotlyError from plotly.graph_objs import Scatter from plotly.grid_objs import Column, Grid -from plotly.plotly.plotly import _api_v2 +from plotly.plotly.plotly import parse_grid_id_args def random_filename(): @@ -124,19 +123,18 @@ def test_get_figure_from_references(): def test_grid_id_args(): assert( - _api_v2.parse_grid_id_args(_grid, None) == - _api_v2.parse_grid_id_args(None, _grid_url) + parse_grid_id_args(_grid, None) == parse_grid_id_args(None, _grid_url) ) @raises(InputError) def test_no_grid_id_args(): - _api_v2.parse_grid_id_args(None, None) + parse_grid_id_args(None, None) @raises(InputError) def test_overspecified_grid_args(): - _api_v2.parse_grid_id_args(_grid, _grid_url) + parse_grid_id_args(_grid, _grid_url) # Out of order usage @@ -149,8 +147,7 @@ def test_scatter_from_non_uploaded_grid(): Scatter(xsrc=g[0], ysrc=g[1]) -@attr('slow') -@raises(requests.exceptions.HTTPError) +@raises(PlotlyError) def test_column_append_of_non_uploaded_grid(): c1 = Column([1, 2, 3, 4], 'first column') c2 = Column(['a', 'b', 'c', 'd'], 'second column') @@ -158,8 +155,7 @@ def test_column_append_of_non_uploaded_grid(): py.grid_ops.append_columns([c2], grid=g) -@attr('slow') -@raises(requests.exceptions.HTTPError) +@raises(PlotlyError) def test_row_append_of_non_uploaded_grid(): c1 = Column([1, 2, 3, 4], 'first column') rows = [[1], [2]] diff --git a/plotly/tests/test_core/test_offline/test_offline.py b/plotly/tests/test_core/test_offline/test_offline.py index f845709a287..cc4b903de71 100644 --- a/plotly/tests/test_core/test_offline/test_offline.py +++ b/plotly/tests/test_core/test_offline/test_offline.py @@ -4,12 +4,13 @@ """ from __future__ import absolute_import -from nose.tools import raises from unittest import TestCase -from plotly.tests.utils import PlotlyTestCase -import json + +from requests.compat import json as _json import plotly +from plotly.tests.utils import PlotlyTestCase + fig = { 'data': [ @@ -35,8 +36,9 @@ def _read_html(self, file_url): return f.read() def test_default_plot_generates_expected_html(self): - data_json = json.dumps(fig['data'], cls=plotly.utils.PlotlyJSONEncoder) - layout_json = json.dumps( + data_json = _json.dumps(fig['data'], + cls=plotly.utils.PlotlyJSONEncoder) + layout_json = _json.dumps( fig['layout'], cls=plotly.utils.PlotlyJSONEncoder) diff --git a/plotly/tests/test_core/test_optional_imports/__init__.py b/plotly/tests/test_core/test_optional_imports/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/plotly/tests/test_core/test_optional_imports/test_optional_imports.py b/plotly/tests/test_core/test_optional_imports/test_optional_imports.py new file mode 100644 index 00000000000..e7569f1609d --- /dev/null +++ b/plotly/tests/test_core/test_optional_imports/test_optional_imports.py @@ -0,0 +1,24 @@ +from __future__ import absolute_import + +from unittest import TestCase + +from plotly.optional_imports import get_module + + +class OptionalImportsTest(TestCase): + + def test_get_module_exists(self): + import math + module = get_module('math') + self.assertIsNotNone(module) + self.assertEqual(math, module) + + def test_get_module_exists_submodule(self): + import requests.sessions + module = get_module('requests.sessions') + self.assertIsNotNone(module) + self.assertEqual(requests.sessions, module) + + def test_get_module_does_not_exist(self): + module = get_module('hoopla') + self.assertIsNone(module) diff --git a/plotly/tests/test_core/test_plotly/test_credentials.py b/plotly/tests/test_core/test_plotly/test_credentials.py index 573b30c91b0..73b9eca8767 100644 --- a/plotly/tests/test_core/test_plotly/test_credentials.py +++ b/plotly/tests/test_core/test_plotly/test_credentials.py @@ -1,39 +1,43 @@ from __future__ import absolute_import -from unittest import TestCase +from mock import patch import plotly.plotly.plotly as py import plotly.session as session import plotly.tools as tls +from plotly import exceptions +from plotly.tests.utils import PlotlyTestCase -def test_get_credentials(): - session_credentials = session.get_session_credentials() - if 'username' in session_credentials: - del session._session['credentials']['username'] - if 'api_key' in session_credentials: - del session._session['credentials']['api_key'] - creds = py.get_credentials() - file_creds = tls.get_credentials_file() - print(creds) - print(file_creds) - assert creds == file_creds +class TestSignIn(PlotlyTestCase): + def setUp(self): + super(TestSignIn, self).setUp() + patcher = patch('plotly.api.v2.users.current') + self.users_current_mock = patcher.start() + self.addCleanup(patcher.stop) -def test_sign_in(): - un = 'anyone' - ak = 'something' - # TODO, add this! - # si = ['this', 'and-this'] - py.sign_in(un, ak) - creds = py.get_credentials() - assert creds['username'] == un - assert creds['api_key'] == ak - # TODO, and check it! - # assert creds['stream_ids'] == si + def test_get_credentials(self): + session_credentials = session.get_session_credentials() + if 'username' in session_credentials: + del session._session['credentials']['username'] + if 'api_key' in session_credentials: + del session._session['credentials']['api_key'] + creds = py.get_credentials() + file_creds = tls.get_credentials_file() + self.assertEqual(creds, file_creds) - -class TestSignIn(TestCase): + def test_sign_in(self): + un = 'anyone' + ak = 'something' + # TODO, add this! + # si = ['this', 'and-this'] + py.sign_in(un, ak) + creds = py.get_credentials() + self.assertEqual(creds['username'], un) + self.assertEqual(creds['api_key'], ak) + # TODO, and check it! + # assert creds['stream_ids'] == si def test_get_config(self): plotly_domain = 'test domain' @@ -74,3 +78,10 @@ def test_sign_in_with_config(self): self.assertEqual( config['plotly_ssl_verification'], plotly_ssl_verification ) + + def test_sign_in_cannot_validate(self): + self.users_current_mock.side_effect = exceptions.PlotlyRequestError( + 'msg', 400, 'foobar' + ) + with self.assertRaisesRegexp(exceptions.PlotlyError, 'Sign in failed'): + py.sign_in('foo', 'bar') diff --git a/plotly/tests/test_core/test_plotly/test_plot.py b/plotly/tests/test_core/test_plotly/test_plot.py index 25b6d208aa3..ef7e797cf58 100644 --- a/plotly/tests/test_core/test_plotly/test_plot.py +++ b/plotly/tests/test_core/test_plotly/test_plot.py @@ -7,11 +7,12 @@ """ from __future__ import absolute_import -import json import requests import six +from requests.compat import json as _json from unittest import TestCase +from mock import patch from nose.plugins.attrib import attr from nose.tools import raises @@ -40,9 +41,14 @@ def test_plot_valid(self): 'x': [1, 2, 3], 'y': [2, 1, 2] } - ] + ], + 'layout': {'title': 'simple'} } - py.plot(fig, auto_open=False, filename='plot_valid') + url = py.plot(fig, auto_open=False, filename='plot_valid') + saved_fig = py.get_figure(url) + self.assertEqual(saved_fig['data'][0]['x'], fig['data'][0]['x']) + self.assertEqual(saved_fig['data'][0]['y'], fig['data'][0]['y']) + self.assertEqual(saved_fig['layout']['title'], fig['layout']['title']) @raises(PlotlyError) def test_plot_invalid(self): @@ -223,6 +229,14 @@ class TestPlotOptionLogic(PlotlyTestCase): {'world_readable': False, 'sharing': 'public'} ) + def setUp(self): + super(TestPlotOptionLogic, self).setUp() + + # Make sure we don't hit sign-in validation failures. + patcher = patch('plotly.api.v2.users.current') + self.users_current_mock = patcher.start() + self.addCleanup(patcher.stop) + def test_default_options(self): options = py._plot_option_logic({}) config_options = tls.get_config_file() @@ -296,10 +310,10 @@ def generate_conflicting_plot_options_with_json_writes_of_config(): """ def gen_test(plot_options): def test(self): - config = json.load(open(CONFIG_FILE)) + config = _json.load(open(CONFIG_FILE)) with open(CONFIG_FILE, 'w') as f: config.update(plot_options) - f.write(json.dumps(config)) + f.write(_json.dumps(config)) self.assertRaises(PlotlyError, py._plot_option_logic, {}) return test diff --git a/plotly/tests/test_core/test_utils/test_utils.py b/plotly/tests/test_core/test_utils/test_utils.py index cb38648b8b6..b406a6464ab 100644 --- a/plotly/tests/test_core/test_utils/test_utils.py +++ b/plotly/tests/test_core/test_utils/test_utils.py @@ -1,8 +1,9 @@ from __future__ import absolute_import -import json from unittest import TestCase +from requests.compat import json as _json + from plotly.utils import PlotlyJSONEncoder, get_by_path, node_generator @@ -10,7 +11,7 @@ class TestJSONEncoder(TestCase): def test_nan_to_null(self): array = [1, float('NaN'), float('Inf'), float('-Inf'), 'platypus'] - result = json.dumps(array, cls=PlotlyJSONEncoder) + result = _json.dumps(array, cls=PlotlyJSONEncoder) expected_result = '[1, null, null, null, "platypus"]' self.assertEqual(result, expected_result) diff --git a/plotly/tests/test_optional/optional_utils.py b/plotly/tests/test_optional/optional_utils.py index 74308de7f53..76941b1b1fe 100644 --- a/plotly/tests/test_optional/optional_utils.py +++ b/plotly/tests/test_optional/optional_utils.py @@ -2,12 +2,13 @@ import numpy as np +from plotly import optional_imports from plotly.tests.utils import is_num_list from plotly.utils import get_by_path, node_generator -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. matplotlib.use('Agg') diff --git a/plotly/tests/test_optional/test_matplotlylib/test_annotations.py b/plotly/tests/test_optional/test_matplotlylib/test_annotations.py index d4987d01b70..b0e238416b9 100644 --- a/plotly/tests/test_optional/test_matplotlylib/test_annotations.py +++ b/plotly/tests/test_optional/test_matplotlylib/test_annotations.py @@ -2,10 +2,11 @@ from nose.plugins.attrib import attr -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported +from plotly import optional_imports -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. matplotlib.use('Agg') diff --git a/plotly/tests/test_optional/test_matplotlylib/test_axis_scales.py b/plotly/tests/test_optional/test_matplotlylib/test_axis_scales.py index 277b886d9fd..2b62f28ceab 100644 --- a/plotly/tests/test_optional/test_matplotlylib/test_axis_scales.py +++ b/plotly/tests/test_optional/test_matplotlylib/test_axis_scales.py @@ -2,14 +2,14 @@ from nose.plugins.attrib import attr +from plotly import optional_imports from plotly.tests.utils import compare_dict from plotly.tests.test_optional.optional_utils import run_fig from plotly.tests.test_optional.test_matplotlylib.data.axis_scales import * -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported +matplotlylib = optional_imports.get_module('plotly.matplotlylib') -if _matplotlylib_imported: +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. matplotlib.use('Agg') diff --git a/plotly/tests/test_optional/test_matplotlylib/test_bars.py b/plotly/tests/test_optional/test_matplotlylib/test_bars.py index feb439dd2c8..c89bfb3a8fc 100644 --- a/plotly/tests/test_optional/test_matplotlylib/test_bars.py +++ b/plotly/tests/test_optional/test_matplotlylib/test_bars.py @@ -2,15 +2,15 @@ from nose.plugins.attrib import attr +from plotly import optional_imports from plotly.tests.utils import compare_dict from plotly.tests.test_optional.optional_utils import run_fig from plotly.tests.test_optional.test_matplotlylib.data.bars import * -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: - import matplotlib +matplotlylib = optional_imports.get_module('plotly.matplotlylib') +if matplotlylib: + import matplotlib # Force matplotlib to not use any Xwindows backend. matplotlib.use('Agg') import matplotlib.pyplot as plt diff --git a/plotly/tests/test_optional/test_matplotlylib/test_data.py b/plotly/tests/test_optional/test_matplotlylib/test_data.py index 33ad00543a3..b188dbc2d46 100644 --- a/plotly/tests/test_optional/test_matplotlylib/test_data.py +++ b/plotly/tests/test_optional/test_matplotlylib/test_data.py @@ -2,12 +2,13 @@ from nose.plugins.attrib import attr +from plotly import optional_imports from plotly.tests.test_optional.optional_utils import run_fig from plotly.tests.test_optional.test_matplotlylib.data.data import * -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. diff --git a/plotly/tests/test_optional/test_matplotlylib/test_date_times.py b/plotly/tests/test_optional/test_matplotlylib/test_date_times.py index 0c56efd5ad6..5f1bf59aefd 100644 --- a/plotly/tests/test_optional/test_matplotlylib/test_date_times.py +++ b/plotly/tests/test_optional/test_matplotlylib/test_date_times.py @@ -8,10 +8,11 @@ from nose.plugins.attrib import attr import plotly.tools as tls +from plotly import optional_imports -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. diff --git a/plotly/tests/test_optional/test_matplotlylib/test_lines.py b/plotly/tests/test_optional/test_matplotlylib/test_lines.py index 9d1074bc571..b7355cccc22 100644 --- a/plotly/tests/test_optional/test_matplotlylib/test_lines.py +++ b/plotly/tests/test_optional/test_matplotlylib/test_lines.py @@ -2,13 +2,14 @@ from nose.plugins.attrib import attr +from plotly import optional_imports from plotly.tests.utils import compare_dict from plotly.tests.test_optional.optional_utils import run_fig from plotly.tests.test_optional.test_matplotlylib.data.lines import * -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. diff --git a/plotly/tests/test_optional/test_matplotlylib/test_scatter.py b/plotly/tests/test_optional/test_matplotlylib/test_scatter.py index 07fc276f5fb..80774c079eb 100644 --- a/plotly/tests/test_optional/test_matplotlylib/test_scatter.py +++ b/plotly/tests/test_optional/test_matplotlylib/test_scatter.py @@ -2,13 +2,14 @@ from nose.plugins.attrib import attr +from plotly import optional_imports from plotly.tests.utils import compare_dict from plotly.tests.test_optional.optional_utils import run_fig from plotly.tests.test_optional.test_matplotlylib.data.scatter import * -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. diff --git a/plotly/tests/test_optional/test_matplotlylib/test_subplots.py b/plotly/tests/test_optional/test_matplotlylib/test_subplots.py index 725ca8aa78b..c30ae23bbf9 100644 --- a/plotly/tests/test_optional/test_matplotlylib/test_subplots.py +++ b/plotly/tests/test_optional/test_matplotlylib/test_subplots.py @@ -2,13 +2,14 @@ from nose.plugins.attrib import attr +from plotly import optional_imports from plotly.tests.utils import compare_dict from plotly.tests.test_optional.optional_utils import run_fig from plotly.tests.test_optional.test_matplotlylib.data.subplots import * -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. diff --git a/plotly/tests/test_optional/test_offline/test_offline.py b/plotly/tests/test_optional/test_offline/test_offline.py index 93d2c4c3770..17dd1af2bdd 100644 --- a/plotly/tests/test_optional/test_offline/test_offline.py +++ b/plotly/tests/test_optional/test_offline/test_offline.py @@ -6,16 +6,16 @@ from nose.tools import raises from nose.plugins.attrib import attr +from requests.compat import json as _json from unittest import TestCase -import json import plotly +from plotly import optional_imports -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported +matplotlylib = optional_imports.get_module('plotly.matplotlylib') -if _matplotlylib_imported: +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. matplotlib.use('Agg') @@ -34,7 +34,6 @@ def test_iplot_doesnt_work_before_you_call_init_notebook_mode(self): plotly.offline.iplot([{}]) def test_iplot_works_after_you_call_init_notebook_mode(self): - plotly.tools._ipython_imported = True plotly.offline.init_notebook_mode() plotly.offline.iplot([{}]) @@ -47,7 +46,6 @@ def test_iplot_mpl_works_after_you_call_init_notebook_mode(self): y = [100, 200, 300] plt.plot(x, y, "o") - plotly.tools._ipython_imported = True plotly.offline.init_notebook_mode() plotly.offline.iplot_mpl(fig) @@ -75,8 +73,8 @@ def test_default_mpl_plot_generates_expected_html(self): figure = plotly.tools.mpl_to_plotly(fig) data = figure['data'] layout = figure['layout'] - data_json = json.dumps(data, cls=plotly.utils.PlotlyJSONEncoder) - layout_json = json.dumps(layout, cls=plotly.utils.PlotlyJSONEncoder) + data_json = _json.dumps(data, cls=plotly.utils.PlotlyJSONEncoder) + layout_json = _json.dumps(layout, cls=plotly.utils.PlotlyJSONEncoder) html = self._read_html(plotly.offline.plot_mpl(fig)) # just make sure a few of the parts are in here diff --git a/plotly/tests/test_optional/test_plotly/test_plot_mpl.py b/plotly/tests/test_optional/test_plotly/test_plot_mpl.py index e509d59d82a..876f7f4b9b3 100644 --- a/plotly/tests/test_optional/test_plotly/test_plot_mpl.py +++ b/plotly/tests/test_optional/test_plotly/test_plot_mpl.py @@ -10,13 +10,13 @@ from nose.plugins.attrib import attr from nose.tools import raises -from plotly import exceptions +from plotly import exceptions, optional_imports from plotly.plotly import plotly as py from unittest import TestCase -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib # Force matplotlib to not use any Xwindows backend. diff --git a/plotly/tests/test_optional/test_tools/__init__.py b/plotly/tests/test_optional/test_tools/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/plotly/tests/test_core/test_tools/test_figure_factory.py b/plotly/tests/test_optional/test_tools/test_figure_factory.py similarity index 99% rename from plotly/tests/test_core/test_tools/test_figure_factory.py rename to plotly/tests/test_optional/test_tools/test_figure_factory.py index 9ebc68c55f7..170920be197 100644 --- a/plotly/tests/test_core/test_tools/test_figure_factory.py +++ b/plotly/tests/test_optional/test_tools/test_figure_factory.py @@ -1649,8 +1649,6 @@ def test_2D_density_all_args(self): # def test_scipy_import_error(self): -# # make sure Import Error is raised when _scipy_imported = False - # hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5, # 3.5, 4.1, 4.4, 4.5, 4.5, # 5.0, 5.0, 5.2, 5.5, 5.5, diff --git a/plotly/tests/test_optional/test_utils/test_utils.py b/plotly/tests/test_optional/test_utils/test_utils.py index a72f89a8b60..9531e12ad88 100644 --- a/plotly/tests/test_optional/test_utils/test_utils.py +++ b/plotly/tests/test_optional/test_utils/test_utils.py @@ -5,7 +5,6 @@ from __future__ import absolute_import import datetime -import json import math import decimal from datetime import datetime as dt @@ -16,14 +15,15 @@ import pytz from nose.plugins.attrib import attr from pandas.util.testing import assert_series_equal +from requests.compat import json as _json -from plotly import utils +from plotly import optional_imports, utils from plotly.graph_objs import Scatter, Scatter3d, Figure, Data from plotly.grid_objs import Column -# TODO: matplotlib-build-wip -from plotly.tools import _matplotlylib_imported -if _matplotlylib_imported: +matplotlylib = optional_imports.get_module('plotly.matplotlylib') + +if matplotlylib: import matplotlib.pyplot as plt from plotly.matplotlylib import Exporter, PlotlyRenderer @@ -179,7 +179,7 @@ def test_column_json_encoding(): Column(mixed_list, 'col 2'), Column(np_list, 'col 3') ] - json_columns = json.dumps( + json_columns = _json.dumps( columns, cls=utils.PlotlyJSONEncoder, sort_keys=True ) assert('[{"data": [1, 2, 3], "name": "col 1"}, ' @@ -198,8 +198,8 @@ def test_figure_json_encoding(): data = Data([s1, s2]) figure = Figure(data=data) - js1 = json.dumps(s1, cls=utils.PlotlyJSONEncoder, sort_keys=True) - js2 = json.dumps(s2, cls=utils.PlotlyJSONEncoder, sort_keys=True) + js1 = _json.dumps(s1, cls=utils.PlotlyJSONEncoder, sort_keys=True) + js2 = _json.dumps(s2, cls=utils.PlotlyJSONEncoder, sort_keys=True) assert(js1 == '{"type": "scatter3d", "x": [1, 2, 3], ' '"y": [1, 2, 3, null, null, null, "2014-01-05"], ' @@ -208,8 +208,8 @@ def test_figure_json_encoding(): assert(js2 == '{"type": "scatter", "x": [1, 2, 3]}') # Test JSON encoding works - json.dumps(data, cls=utils.PlotlyJSONEncoder, sort_keys=True) - json.dumps(figure, cls=utils.PlotlyJSONEncoder, sort_keys=True) + _json.dumps(data, cls=utils.PlotlyJSONEncoder, sort_keys=True) + _json.dumps(figure, cls=utils.PlotlyJSONEncoder, sort_keys=True) # Test data wasn't mutated assert(bool(np.asarray(np_list == @@ -221,18 +221,18 @@ def test_figure_json_encoding(): def test_datetime_json_encoding(): - j1 = json.dumps(dt_list, cls=utils.PlotlyJSONEncoder) + j1 = _json.dumps(dt_list, cls=utils.PlotlyJSONEncoder) assert(j1 == '["2014-01-05", ' '"2014-01-05 01:01:01", ' '"2014-01-05 01:01:01.000001"]') - j2 = json.dumps({"x": dt_list}, cls=utils.PlotlyJSONEncoder) + j2 = _json.dumps({"x": dt_list}, cls=utils.PlotlyJSONEncoder) assert(j2 == '{"x": ["2014-01-05", ' '"2014-01-05 01:01:01", ' '"2014-01-05 01:01:01.000001"]}') def test_pandas_json_encoding(): - j1 = json.dumps(df['col 1'], cls=utils.PlotlyJSONEncoder) + j1 = _json.dumps(df['col 1'], cls=utils.PlotlyJSONEncoder) assert(j1 == '[1, 2, 3, "2014-01-05", null, null, null]') # Test that data wasn't mutated @@ -240,28 +240,28 @@ def test_pandas_json_encoding(): pd.Series([1, 2, 3, dt(2014, 1, 5), pd.NaT, np.NaN, np.Inf], name='col 1')) - j2 = json.dumps(df.index, cls=utils.PlotlyJSONEncoder) + j2 = _json.dumps(df.index, cls=utils.PlotlyJSONEncoder) assert(j2 == '[0, 1, 2, 3, 4, 5, 6]') nat = [pd.NaT] - j3 = json.dumps(nat, cls=utils.PlotlyJSONEncoder) + j3 = _json.dumps(nat, cls=utils.PlotlyJSONEncoder) assert(j3 == '[null]') assert(nat[0] is pd.NaT) - j4 = json.dumps(rng, cls=utils.PlotlyJSONEncoder) + j4 = _json.dumps(rng, cls=utils.PlotlyJSONEncoder) assert(j4 == '["2011-01-01", "2011-01-01 01:00:00"]') - j5 = json.dumps(ts, cls=utils.PlotlyJSONEncoder) + j5 = _json.dumps(ts, cls=utils.PlotlyJSONEncoder) assert(j5 == '[1.5, 2.5]') assert_series_equal(ts, pd.Series([1.5, 2.5], index=rng)) - j6 = json.dumps(ts.index, cls=utils.PlotlyJSONEncoder) + j6 = _json.dumps(ts.index, cls=utils.PlotlyJSONEncoder) assert(j6 == '["2011-01-01", "2011-01-01 01:00:00"]') def test_numpy_masked_json_encoding(): l = [1, 2, np.ma.core.masked] - j1 = json.dumps(l, cls=utils.PlotlyJSONEncoder) + j1 = _json.dumps(l, cls=utils.PlotlyJSONEncoder) print(j1) assert(j1 == '[1, 2, null]') @@ -285,18 +285,18 @@ def test_masked_constants_example(): renderer = PlotlyRenderer() Exporter(renderer).run(fig) - json.dumps(renderer.plotly_fig, cls=utils.PlotlyJSONEncoder) + _json.dumps(renderer.plotly_fig, cls=utils.PlotlyJSONEncoder) - jy = json.dumps(renderer.plotly_fig['data'][1]['y'], + jy = _json.dumps(renderer.plotly_fig['data'][1]['y'], cls=utils.PlotlyJSONEncoder) print(jy) - array = json.loads(jy) + array = _json.loads(jy) assert(array == [-398.11793027, -398.11792966, -398.11786308, None]) def test_numpy_dates(): a = np.arange(np.datetime64('2011-07-11'), np.datetime64('2011-07-18')) - j1 = json.dumps(a, cls=utils.PlotlyJSONEncoder) + j1 = _json.dumps(a, cls=utils.PlotlyJSONEncoder) assert(j1 == '["2011-07-11", "2011-07-12", "2011-07-13", ' '"2011-07-14", "2011-07-15", "2011-07-16", ' '"2011-07-17"]') @@ -304,5 +304,5 @@ def test_numpy_dates(): def test_datetime_dot_date(): a = [datetime.date(2014, 1, 1), datetime.date(2014, 1, 2)] - j1 = json.dumps(a, cls=utils.PlotlyJSONEncoder) + j1 = _json.dumps(a, cls=utils.PlotlyJSONEncoder) assert(j1 == '["2014-01-01", "2014-01-02"]') diff --git a/plotly/tests/utils.py b/plotly/tests/utils.py index f8b1438e6f1..2d1113d68b4 100644 --- a/plotly/tests/utils.py +++ b/plotly/tests/utils.py @@ -1,5 +1,4 @@ import copy -import json from numbers import Number as Num from unittest import TestCase diff --git a/plotly/tools.py b/plotly/tools.py index e43106a940d..c45d4cdabcd 100644 --- a/plotly/tools.py +++ b/plotly/tools.py @@ -8,19 +8,12 @@ """ from __future__ import absolute_import -from collections import OrderedDict import warnings import six -import math -import decimal - -from plotly import colors -from plotly import utils -from plotly import exceptions -from plotly import graph_reference -from plotly import session + +from plotly import exceptions, optional_imports, session, utils from plotly.files import (CONFIG_FILE, CREDENTIALS_FILE, FILE_CONTENT, check_file_permissions) @@ -63,55 +56,9 @@ def warning_on_one_line(message, category, filename, lineno, message) warnings.formatwarning = warning_on_one_line -try: - from . import matplotlylib - _matplotlylib_imported = True -except ImportError: - _matplotlylib_imported = False - -try: - import IPython - import IPython.core.display - _ipython_imported = True -except ImportError: - _ipython_imported = False - -try: - import numpy as np - _numpy_imported = True -except ImportError: - _numpy_imported = False - -try: - import pandas as pd - _pandas_imported = True -except ImportError: - _pandas_imported = False - -try: - import scipy as scp - _scipy_imported = True -except ImportError: - _scipy_imported = False - -try: - import scipy.spatial as scs - _scipy__spatial_imported = True -except ImportError: - _scipy__spatial_imported = False - -try: - import scipy.cluster.hierarchy as sch - _scipy__cluster__hierarchy_imported = True -except ImportError: - _scipy__cluster__hierarchy_imported = False - -try: - import scipy - import scipy.stats - _scipy_imported = True -except ImportError: - _scipy_imported = False +ipython_core_display = optional_imports.get_module('IPython.core.display') +matplotlylib = optional_imports.get_module('plotly.matplotlylib') +sage_salvus = optional_imports.get_module('sage_salvus') def get_config_defaults(): @@ -419,11 +366,11 @@ def embed(file_owner_or_url, file_id=None, width="100%", height=525): height=height) # see if we are in the SageMath Cloud - from sage_salvus import html - return html(s, hide=False) + if sage_salvus: + return sage_salvus.html(s, hide=False) except: pass - if _ipython_imported: + if ipython_core_display: if file_id: plotly_domain = ( session.get_session_config().get('plotly_domain') or @@ -502,7 +449,7 @@ def mpl_to_plotly(fig, resize=False, strip_style=False, verbose=False): {plotly_domain}/python/getting-started """ - if _matplotlylib_imported: + if matplotlylib: renderer = matplotlylib.PlotlyRenderer() matplotlylib.Exporter(renderer).run(fig) if resize: @@ -1357,6 +1304,7 @@ def validate(obj, obj_type): """ # TODO: Deprecate or move. #283 + from plotly import graph_reference from plotly.graph_objs import graph_objs if obj_type not in graph_reference.CLASSES: @@ -1397,8 +1345,8 @@ def _replace_newline(obj): return obj # we return the actual reference... but DON'T mutate. -if _ipython_imported: - class PlotlyDisplay(IPython.core.display.HTML): +if ipython_core_display: + class PlotlyDisplay(ipython_core_display.HTML): """An IPython display object for use with plotly urls PlotlyDisplay objects should be instantiated with a url for a plot. @@ -1459,6137 +1407,91 @@ def return_figure_from_figure_or_data(figure_or_data, validate_figure): class FigureFactory(object): - """ - BETA functions to create specific chart types. - - This is beta as in: subject to change in a backwards incompatible way - without notice. - - Supported chart types include candlestick, open high low close, quiver, - streamline, distplot, dendrogram, annotated heatmap, and tables. See - FigureFactory.create_candlestick, FigureFactory.create_ohlc, - FigureFactory.create_quiver, FigureFactory.create_streamline, - FigureFactory.create_distplot, FigureFactory.create_dendrogram, - FigureFactory.create_annotated_heatmap, or FigureFactory.create_table for - more information and examples of a specific chart type. - """ - - @staticmethod - def _make_colorscale(colors, scale=None): - """ - Makes a colorscale from a list of colors and scale - - Takes a list of colors and scales and constructs a colorscale based - on the colors in sequential order. If 'scale' is left empty, a linear- - interpolated colorscale will be generated. If 'scale' is a specificed - list, it must be the same legnth as colors and must contain all floats - For documentation regarding to the form of the output, see - https://plot.ly/python/reference/#mesh3d-colorscale - """ - colorscale = [] - - if not scale: - for j, color in enumerate(colors): - colorscale.append([j * 1./(len(colors) - 1), color]) - return colorscale - - else: - colorscale = [list(tup) for tup in zip(scale, colors)] - return colorscale - - @staticmethod - def _convert_colorscale_to_rgb(colorscale): - """ - Converts the colors in a colorscale to rgb colors - - A colorscale is an array of arrays, each with a numeric value as the - first item and a color as the second. This function specifically is - converting a colorscale with tuple colors (each coordinate between 0 - and 1) into a colorscale with the colors transformed into rgb colors - """ - for color in colorscale: - color[1] = FigureFactory._convert_to_RGB_255( - color[1] - ) - - for color in colorscale: - color[1] = FigureFactory._label_rgb( - color[1] - ) - return colorscale - - @staticmethod - def _make_linear_colorscale(colors): - """ - Makes a list of colors into a colorscale-acceptable form - - For documentation regarding to the form of the output, see - https://plot.ly/python/reference/#mesh3d-colorscale - """ - scale = 1./(len(colors) - 1) - return[[i * scale, color] for i, color in enumerate(colors)] - - @staticmethod - def create_2D_density(x, y, colorscale='Earth', ncontours=20, - hist_color=(0, 0, 0.5), point_color=(0, 0, 0.5), - point_size=2, title='2D Density Plot', - height=600, width=600): - """ - Returns figure for a 2D density plot - - :param (list|array) x: x-axis data for plot generation - :param (list|array) y: y-axis data for plot generation - :param (str|tuple|list) colorscale: either a plotly scale name, an rgb - or hex color, a color tuple or a list or tuple of colors. An rgb - color is of the form 'rgb(x, y, z)' where x, y, z belong to the - interval [0, 255] and a color tuple is a tuple of the form - (a, b, c) where a, b and c belong to [0, 1]. If colormap is a - list, it must contain the valid color types aforementioned as its - members. - :param (int) ncontours: the number of 2D contours to draw on the plot - :param (str) hist_color: the color of the plotted histograms - :param (str) point_color: the color of the scatter points - :param (str) point_size: the color of the scatter points - :param (str) title: set the title for the plot - :param (float) height: the height of the chart - :param (float) width: the width of the chart - - Example 1: Simple 2D Density Plot - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import numpy as np - - # Make data points - t = np.linspace(-1,1.2,2000) - x = (t**3)+(0.3*np.random.randn(2000)) - y = (t**6)+(0.3*np.random.randn(2000)) - - # Create a figure - fig = FF.create_2D_density(x, y) - - # Plot the data - py.iplot(fig, filename='simple-2d-density') - ``` - - Example 2: Using Parameters - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import numpy as np - - # Make data points - t = np.linspace(-1,1.2,2000) - x = (t**3)+(0.3*np.random.randn(2000)) - y = (t**6)+(0.3*np.random.randn(2000)) - - # Create custom colorscale - colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)', - (1, 1, 0.2), (0.98,0.98,0.98)] - - # Create a figure - fig = FF.create_2D_density( - x, y, colorscale=colorscale, - hist_color='rgb(255, 237, 222)', point_size=3) - - # Plot the data - py.iplot(fig, filename='use-parameters') - ``` - """ - from plotly.graph_objs import graph_objs - from numbers import Number - - # validate x and y are filled with numbers only - for array in [x, y]: - if not all(isinstance(element, Number) for element in array): - raise exceptions.PlotlyError( - "All elements of your 'x' and 'y' lists must be numbers." - ) - - # validate x and y are the same length - if len(x) != len(y): - raise exceptions.PlotlyError( - "Both lists 'x' and 'y' must be the same length." - ) - - colorscale = FigureFactory._validate_colors(colorscale, 'rgb') - colorscale = FigureFactory._make_linear_colorscale(colorscale) - - # validate hist_color and point_color - hist_color = FigureFactory._validate_colors(hist_color, 'rgb') - point_color = FigureFactory._validate_colors(point_color, 'rgb') - - trace1 = graph_objs.Scatter( - x=x, y=y, mode='markers', name='points', - marker=dict( - color=point_color[0], - size=point_size, - opacity=0.4 - ) - ) - trace2 = graph_objs.Histogram2dcontour( - x=x, y=y, name='density', ncontours=ncontours, - colorscale=colorscale, reversescale=True, showscale=False - ) - trace3 = graph_objs.Histogram( - x=x, name='x density', - marker=dict(color=hist_color[0]), yaxis='y2' - ) - trace4 = graph_objs.Histogram( - y=y, name='y density', - marker=dict(color=hist_color[0]), xaxis='x2' - ) - data = [trace1, trace2, trace3, trace4] - - layout = graph_objs.Layout( - showlegend=False, - autosize=False, - title=title, - height=height, - width=width, - xaxis=dict( - domain=[0, 0.85], - showgrid=False, - zeroline=False - ), - yaxis=dict( - domain=[0, 0.85], - showgrid=False, - zeroline=False - ), - margin=dict( - t=50 - ), - hovermode='closest', - bargap=0, - xaxis2=dict( - domain=[0.85, 1], - showgrid=False, - zeroline=False - ), - yaxis2=dict( - domain=[0.85, 1], - showgrid=False, - zeroline=False - ) - ) - - fig = graph_objs.Figure(data=data, layout=layout) - return fig - - @staticmethod - def _validate_gantt(df): - """ - Validates the inputted dataframe or list - """ - if _pandas_imported and isinstance(df, pd.core.frame.DataFrame): - # validate that df has all the required keys - for key in REQUIRED_GANTT_KEYS: - if key not in df: - raise exceptions.PlotlyError( - "The columns in your dataframe must include the " - "following keys: {0}".format(', '.join(REQUIRED_GANTT_KEYS)) - ) - - num_of_rows = len(df.index) - chart = [] - for index in range(num_of_rows): - task_dict = {} - for key in df: - task_dict[key] = df.ix[index][key] - chart.append(task_dict) - - return chart - - # validate if df is a list - if not isinstance(df, list): - raise exceptions.PlotlyError("You must input either a dataframe " - "or a list of dictionaries.") - - # validate if df is empty - if len(df) <= 0: - raise exceptions.PlotlyError("Your list is empty. It must contain " - "at least one dictionary.") - if not isinstance(df[0], dict): - raise exceptions.PlotlyError("Your list must only " - "include dictionaries.") - return df - - @staticmethod - def _gantt(chart, colors, title, bar_width, showgrid_x, showgrid_y, - height, width, tasks=None, task_names=None, data=None, group_tasks=False): - """ - Refer to FigureFactory.create_gantt() for docstring - """ - if tasks is None: - tasks = [] - if task_names is None: - task_names = [] - if data is None: - data = [] - - for index in range(len(chart)): - task = dict(x0=chart[index]['Start'], - x1=chart[index]['Finish'], - name=chart[index]['Task']) - if 'Description' in chart[index]: - task['description'] = chart[index]['Description'] - tasks.append(task) - - shape_template = { - 'type': 'rect', - 'xref': 'x', - 'yref': 'y', - 'opacity': 1, - 'line': { - 'width': 0, - }, - 'yref': 'y', - } - # create the list of task names - for index in range(len(tasks)): - tn = tasks[index]['name'] - # Is added to task_names if group_tasks is set to False, - # or if the option is used (True) it only adds them if the - # name is not already in the list - if not group_tasks or tn not in task_names: - task_names.append(tn) - # Guarantees that for grouped tasks the tasks that are inserted first - # are shown at the top - if group_tasks: - task_names.reverse() - - - color_index = 0 - for index in range(len(tasks)): - tn = tasks[index]['name'] - del tasks[index]['name'] - tasks[index].update(shape_template) - - # If group_tasks is True, all tasks with the same name belong - # to the same row. - groupID = index - if group_tasks: - groupID = task_names.index(tn) - tasks[index]['y0'] = groupID - bar_width - tasks[index]['y1'] = groupID + bar_width - - # check if colors need to be looped - if color_index >= len(colors): - color_index = 0 - tasks[index]['fillcolor'] = colors[color_index] - # Add a line for hover text and autorange - entry = dict( - x=[tasks[index]['x0'], tasks[index]['x1']], - y=[groupID, groupID], - name='', - marker={'color': 'white'} - ) - if "description" in tasks[index]: - entry['text'] = tasks[index]['description'] - del tasks[index]['description'] - data.append(entry) - color_index += 1 - - layout = dict( - title=title, - showlegend=False, - height=height, - width=width, - shapes=[], - hovermode='closest', - yaxis=dict( - showgrid=showgrid_y, - ticktext=task_names, - tickvals=list(range(len(task_names))), - range=[-1, len(task_names) + 1], - autorange=False, - zeroline=False, - ), - xaxis=dict( - showgrid=showgrid_x, - zeroline=False, - rangeselector=dict( - buttons=list([ - dict(count=7, - label='1w', - step='day', - stepmode='backward'), - dict(count=1, - label='1m', - step='month', - stepmode='backward'), - dict(count=6, - label='6m', - step='month', - stepmode='backward'), - dict(count=1, - label='YTD', - step='year', - stepmode='todate'), - dict(count=1, - label='1y', - step='year', - stepmode='backward'), - dict(step='all') - ]) - ), - type='date' - ) - ) - layout['shapes'] = tasks - - fig = dict(data=data, layout=layout) - return fig - - @staticmethod - def _gantt_colorscale(chart, colors, title, index_col, show_colorbar, - bar_width, showgrid_x, showgrid_y, height, - width, tasks=None, task_names=None, data=None, group_tasks=False): - """ - Refer to FigureFactory.create_gantt() for docstring - """ - from numbers import Number - if tasks is None: - tasks = [] - if task_names is None: - task_names = [] - if data is None: - data = [] - showlegend = False - - for index in range(len(chart)): - task = dict(x0=chart[index]['Start'], - x1=chart[index]['Finish'], - name=chart[index]['Task']) - if 'Description' in chart[index]: - task['description'] = chart[index]['Description'] - tasks.append(task) - - shape_template = { - 'type': 'rect', - 'xref': 'x', - 'yref': 'y', - 'opacity': 1, - 'line': { - 'width': 0, - }, - 'yref': 'y', - } - - # compute the color for task based on indexing column - if isinstance(chart[0][index_col], Number): - # check that colors has at least 2 colors - if len(colors) < 2: - raise exceptions.PlotlyError( - "You must use at least 2 colors in 'colors' if you " - "are using a colorscale. However only the first two " - "colors given will be used for the lower and upper " - "bounds on the colormap." - ) - - # create the list of task names - for index in range(len(tasks)): - tn = tasks[index]['name'] - # Is added to task_names if group_tasks is set to False, - # or if the option is used (True) it only adds them if the - # name is not already in the list - if not group_tasks or tn not in task_names: - task_names.append(tn) - # Guarantees that for grouped tasks the tasks that are inserted - # first are shown at the top - if group_tasks: - task_names.reverse() - - for index in range(len(tasks)): - tn = tasks[index]['name'] - del tasks[index]['name'] - tasks[index].update(shape_template) - - # If group_tasks is True, all tasks with the same name belong - # to the same row. - groupID = index - if group_tasks: - groupID = task_names.index(tn) - tasks[index]['y0'] = groupID - bar_width - tasks[index]['y1'] = groupID + bar_width - - # unlabel color - colors = FigureFactory._color_parser( - colors, FigureFactory._unlabel_rgb - ) - lowcolor = colors[0] - highcolor = colors[1] - - intermed = (chart[index][index_col])/100.0 - intermed_color = FigureFactory._find_intermediate_color( - lowcolor, highcolor, intermed - ) - intermed_color = FigureFactory._color_parser( - intermed_color, FigureFactory._label_rgb - ) - tasks[index]['fillcolor'] = intermed_color - # relabel colors with 'rgb' - colors = FigureFactory._color_parser( - colors, FigureFactory._label_rgb - ) - - # add a line for hover text and autorange - entry = dict( - x=[tasks[index]['x0'], tasks[index]['x1']], - y=[groupID, groupID], - name='', - marker={'color': 'white'} - ) - if "description" in tasks[index]: - entry['text'] = tasks[index]['description'] - del tasks[index]['description'] - data.append(entry) - - - if show_colorbar is True: - # generate dummy data for colorscale visibility - data.append( - dict( - x=[tasks[index]['x0'], tasks[index]['x0']], - y=[index, index], - name='', - marker={'color': 'white', - 'colorscale': [[0, colors[0]], [1, colors[1]]], - 'showscale': True, - 'cmax': 100, - 'cmin': 0} - ) - ) - - if isinstance(chart[0][index_col], str): - index_vals = [] - for row in range(len(tasks)): - if chart[row][index_col] not in index_vals: - index_vals.append(chart[row][index_col]) - - index_vals.sort() - - if len(colors) < len(index_vals): - raise exceptions.PlotlyError( - "Error. The number of colors in 'colors' must be no less " - "than the number of unique index values in your group " - "column." - ) - - # make a dictionary assignment to each index value - index_vals_dict = {} - # define color index - c_index = 0 - for key in index_vals: - if c_index > len(colors) - 1: - c_index = 0 - index_vals_dict[key] = colors[c_index] - c_index += 1 - - # create the list of task names - for index in range(len(tasks)): - tn = tasks[index]['name'] - # Is added to task_names if group_tasks is set to False, - # or if the option is used (True) it only adds them if the - # name is not already in the list - if not group_tasks or tn not in task_names: - task_names.append(tn) - # Guarantees that for grouped tasks the tasks that are inserted - # first are shown at the top - if group_tasks: - task_names.reverse() - - for index in range(len(tasks)): - tn = tasks[index]['name'] - del tasks[index]['name'] - tasks[index].update(shape_template) - # If group_tasks is True, all tasks with the same name belong - # to the same row. - groupID = index - if group_tasks: - groupID = task_names.index(tn) - tasks[index]['y0'] = groupID - bar_width - tasks[index]['y1'] = groupID + bar_width - - tasks[index]['fillcolor'] = index_vals_dict[ - chart[index][index_col] - ] - - # add a line for hover text and autorange - entry = dict( - x=[tasks[index]['x0'], tasks[index]['x1']], - y=[groupID, groupID], - name='', - marker={'color': 'white'} - ) - if "description" in tasks[index]: - entry['text'] = tasks[index]['description'] - del tasks[index]['description'] - data.append(entry) - - if show_colorbar is True: - # generate dummy data to generate legend - showlegend = True - for k, index_value in enumerate(index_vals): - data.append( - dict( - x=[tasks[index]['x0'], tasks[index]['x0']], - y=[k, k], - showlegend=True, - name=str(index_value), - hoverinfo='none', - marker=dict( - color=colors[k], - size=1 - ) - ) - ) - - layout = dict( - title=title, - showlegend=showlegend, - height=height, - width=width, - shapes=[], - hovermode='closest', - yaxis=dict( - showgrid=showgrid_y, - ticktext=task_names, - tickvals=list(range(len(task_names))), - range=[-1, len(task_names) + 1], - autorange=False, - zeroline=False, - ), - xaxis=dict( - showgrid=showgrid_x, - zeroline=False, - rangeselector=dict( - buttons=list([ - dict(count=7, - label='1w', - step='day', - stepmode='backward'), - dict(count=1, - label='1m', - step='month', - stepmode='backward'), - dict(count=6, - label='6m', - step='month', - stepmode='backward'), - dict(count=1, - label='YTD', - step='year', - stepmode='todate'), - dict(count=1, - label='1y', - step='year', - stepmode='backward'), - dict(step='all') - ]) - ), - type='date' - ) - ) - layout['shapes'] = tasks - - fig = dict(data=data, layout=layout) - return fig - - @staticmethod - def _gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width, - showgrid_x, showgrid_y, height, width, tasks=None, - task_names=None, data=None, group_tasks=False): - """ - Refer to FigureFactory.create_gantt() for docstring - """ - if tasks is None: - tasks = [] - if task_names is None: - task_names = [] - if data is None: - data = [] - showlegend = False - - for index in range(len(chart)): - task = dict(x0=chart[index]['Start'], - x1=chart[index]['Finish'], - name=chart[index]['Task']) - if 'Description' in chart[index]: - task['description'] = chart[index]['Description'] - tasks.append(task) - - shape_template = { - 'type': 'rect', - 'xref': 'x', - 'yref': 'y', - 'opacity': 1, - 'line': { - 'width': 0, - }, - 'yref': 'y', - } - - index_vals = [] - for row in range(len(tasks)): - if chart[row][index_col] not in index_vals: - index_vals.append(chart[row][index_col]) - - index_vals.sort() - - # verify each value in index column appears in colors dictionary - for key in index_vals: - if key not in colors: - raise exceptions.PlotlyError( - "If you are using colors as a dictionary, all of its " - "keys must be all the values in the index column." - ) - - # create the list of task names - for index in range(len(tasks)): - tn = tasks[index]['name'] - # Is added to task_names if group_tasks is set to False, - # or if the option is used (True) it only adds them if the - # name is not already in the list - if not group_tasks or tn not in task_names: - task_names.append(tn) - # Guarantees that for grouped tasks the tasks that are inserted first - # are shown at the top - if group_tasks: - task_names.reverse() - - for index in range(len(tasks)): - tn = tasks[index]['name'] - del tasks[index]['name'] - tasks[index].update(shape_template) - - # If group_tasks is True, all tasks with the same name belong - # to the same row. - groupID = index - if group_tasks: - groupID = task_names.index(tn) - tasks[index]['y0'] = groupID - bar_width - tasks[index]['y1'] = groupID + bar_width - - tasks[index]['fillcolor'] = colors[chart[index][index_col]] - - # add a line for hover text and autorange - entry = dict( - x=[tasks[index]['x0'], tasks[index]['x1']], - y=[groupID, groupID], - name='', - marker={'color': 'white'} - ) - if "description" in tasks[index]: - entry['text'] = tasks[index]['description'] - del tasks[index]['description'] - data.append(entry) - - if show_colorbar is True: - # generate dummy data to generate legend - showlegend = True - for k, index_value in enumerate(index_vals): - data.append( - dict( - x=[tasks[index]['x0'], tasks[index]['x0']], - y=[k, k], - showlegend=True, - hoverinfo='none', - name=str(index_value), - marker=dict( - color=colors[index_value], - size=1 - ) - ) - ) - - layout = dict( - title=title, - showlegend=showlegend, - height=height, - width=width, - shapes=[], - hovermode='closest', - yaxis=dict( - showgrid=showgrid_y, - ticktext=task_names, - tickvals=list(range(len(task_names))), - range=[-1, len(task_names) + 1], - autorange=False, - zeroline=False, - ), - xaxis=dict( - showgrid=showgrid_x, - zeroline=False, - rangeselector=dict( - buttons=list([ - dict(count=7, - label='1w', - step='day', - stepmode='backward'), - dict(count=1, - label='1m', - step='month', - stepmode='backward'), - dict(count=6, - label='6m', - step='month', - stepmode='backward'), - dict(count=1, - label='YTD', - step='year', - stepmode='todate'), - dict(count=1, - label='1y', - step='year', - stepmode='backward'), - dict(step='all') - ]) - ), - type='date' - ) - ) - layout['shapes'] = tasks - - fig = dict(data=data, layout=layout) - return fig - - @staticmethod - def create_gantt(df, colors=None, index_col=None, show_colorbar=False, - reverse_colors=False, title='Gantt Chart', - bar_width=0.2, showgrid_x=False, showgrid_y=False, - height=600, width=900, tasks=None, - task_names=None, data=None, group_tasks=False): - """ - Returns figure for a gantt chart - - :param (array|list) df: input data for gantt chart. Must be either a - a dataframe or a list. If dataframe, the columns must include - 'Task', 'Start' and 'Finish'. Other columns can be included and - used for indexing. If a list, its elements must be dictionaries - with the same required column headers: 'Task', 'Start' and - 'Finish'. - :param (str|list|dict|tuple) colors: either a plotly scale name, an - rgb or hex color, a color tuple or a list of colors. An rgb color - is of the form 'rgb(x, y, z)' where x, y, z belong to the interval - [0, 255] and a color tuple is a tuple of the form (a, b, c) where - a, b and c belong to [0, 1]. If colors is a list, it must - contain the valid color types aforementioned as its members. - If a dictionary, all values of the indexing column must be keys in - colors. - :param (str|float) index_col: the column header (if df is a data - frame) that will function as the indexing column. If df is a list, - index_col must be one of the keys in all the items of df. - :param (bool) show_colorbar: determines if colorbar will be visible. - Only applies if values in the index column are numeric. - :param (bool) reverse_colors: reverses the order of selected colors - :param (str) title: the title of the chart - :param (float) bar_width: the width of the horizontal bars in the plot - :param (bool) showgrid_x: show/hide the x-axis grid - :param (bool) showgrid_y: show/hide the y-axis grid - :param (float) height: the height of the chart - :param (float) width: the width of the chart - - Example 1: Simple Gantt Chart - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - # Make data for chart - df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'), - dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'), - dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')] - - # Create a figure - fig = FF.create_gantt(df) - - # Plot the data - py.iplot(fig, filename='Simple Gantt Chart', world_readable=True) - ``` - - Example 2: Index by Column with Numerical Entries - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - # Make data for chart - df = [dict(Task="Job A", Start='2009-01-01', - Finish='2009-02-30', Complete=10), - dict(Task="Job B", Start='2009-03-05', - Finish='2009-04-15', Complete=60), - dict(Task="Job C", Start='2009-02-20', - Finish='2009-05-30', Complete=95)] - - # Create a figure with Plotly colorscale - fig = FF.create_gantt(df, colors='Blues', index_col='Complete', - show_colorbar=True, bar_width=0.5, - showgrid_x=True, showgrid_y=True) - - # Plot the data - py.iplot(fig, filename='Numerical Entries', world_readable=True) - ``` - - Example 3: Index by Column with String Entries - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - # Make data for chart - df = [dict(Task="Job A", Start='2009-01-01', - Finish='2009-02-30', Resource='Apple'), - dict(Task="Job B", Start='2009-03-05', - Finish='2009-04-15', Resource='Grape'), - dict(Task="Job C", Start='2009-02-20', - Finish='2009-05-30', Resource='Banana')] - - # Create a figure with Plotly colorscale - fig = FF.create_gantt(df, colors=['rgb(200, 50, 25)', - (1, 0, 1), - '#6c4774'], - index_col='Resource', - reverse_colors=True, - show_colorbar=True) - - # Plot the data - py.iplot(fig, filename='String Entries', world_readable=True) - ``` - - Example 4: Use a dictionary for colors - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - # Make data for chart - df = [dict(Task="Job A", Start='2009-01-01', - Finish='2009-02-30', Resource='Apple'), - dict(Task="Job B", Start='2009-03-05', - Finish='2009-04-15', Resource='Grape'), - dict(Task="Job C", Start='2009-02-20', - Finish='2009-05-30', Resource='Banana')] - - # Make a dictionary of colors - colors = {'Apple': 'rgb(255, 0, 0)', - 'Grape': 'rgb(170, 14, 200)', - 'Banana': (1, 1, 0.2)} - - # Create a figure with Plotly colorscale - fig = FF.create_gantt(df, colors=colors, - index_col='Resource', - show_colorbar=True) - - # Plot the data - py.iplot(fig, filename='dictioanry colors', world_readable=True) - ``` - - Example 5: Use a pandas dataframe - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import pandas as pd - - # Make data as a dataframe - df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10], - ['Fast', '2011-01-01', '2012-06-05', 55], - ['Eat', '2012-01-05', '2013-07-05', 94]], - columns=['Task', 'Start', 'Finish', 'Complete']) - - # Create a figure with Plotly colorscale - fig = FF.create_gantt(df, colors='Blues', index_col='Complete', - show_colorbar=True, bar_width=0.5, - showgrid_x=True, showgrid_y=True) - - # Plot the data - py.iplot(fig, filename='data with dataframe', world_readable=True) - ``` - """ - # validate gantt input data - chart = FigureFactory._validate_gantt(df) - - if index_col: - if index_col not in chart[0]: - raise exceptions.PlotlyError( - "In order to use an indexing column and assign colors to " - "the values of the index, you must choose an actual " - "column name in the dataframe or key if a list of " - "dictionaries is being used.") - - # validate gantt index column - index_list = [] - for dictionary in chart: - index_list.append(dictionary[index_col]) - FigureFactory._validate_index(index_list) - - # Validate colors - if isinstance(colors, dict): - colors = FigureFactory._validate_colors_dict(colors, 'rgb') - else: - colors = FigureFactory._validate_colors(colors, 'rgb') - - if reverse_colors is True: - colors.reverse() - - if not index_col: - if isinstance(colors, dict): - raise exceptions.PlotlyError( - "Error. You have set colors to a dictionary but have not " - "picked an index. An index is required if you are " - "assigning colors to particular values in a dictioanry." - ) - fig = FigureFactory._gantt( - chart, colors, title, bar_width, showgrid_x, showgrid_y, - height, width, tasks=None, task_names=None, data=None, - group_tasks=group_tasks - ) - return fig - else: - if not isinstance(colors, dict): - fig = FigureFactory._gantt_colorscale( - chart, colors, title, index_col, show_colorbar, bar_width, - showgrid_x, showgrid_y, height, width, - tasks=None, task_names=None, data=None, group_tasks=group_tasks - ) - return fig - else: - fig = FigureFactory._gantt_dict( - chart, colors, title, index_col, show_colorbar, bar_width, - showgrid_x, showgrid_y, height, width, - tasks=None, task_names=None, data=None, group_tasks=group_tasks - ) - return fig - - @staticmethod - def _validate_colors(colors, colortype='tuple'): - """ - Validates color(s) and returns a list of color(s) of a specified type - """ - from numbers import Number - if colors is None: - colors = DEFAULT_PLOTLY_COLORS - - if isinstance(colors, str): - if colors in PLOTLY_SCALES: - colors = PLOTLY_SCALES[colors] - elif 'rgb' in colors or '#' in colors: - colors = [colors] - else: - raise exceptions.PlotlyError( - "If your colors variable is a string, it must be a " - "Plotly scale, an rgb color or a hex color.") - - elif isinstance(colors, tuple): - if isinstance(colors[0], Number): - colors = [colors] - else: - colors = list(colors) - - - # convert color elements in list to tuple color - for j, each_color in enumerate(colors): - if 'rgb' in each_color: - each_color = FigureFactory._color_parser( - each_color, FigureFactory._unlabel_rgb - ) - for value in each_color: - if value > 255.0: - raise exceptions.PlotlyError( - "Whoops! The elements in your rgb colors " - "tuples cannot exceed 255.0." - ) - each_color = FigureFactory._color_parser( - each_color, FigureFactory._unconvert_from_RGB_255 - ) - colors[j] = each_color - - if '#' in each_color: - each_color = FigureFactory._color_parser( - each_color, FigureFactory._hex_to_rgb - ) - each_color = FigureFactory._color_parser( - each_color, FigureFactory._unconvert_from_RGB_255 - ) - - colors[j] = each_color - - if isinstance(each_color, tuple): - for value in each_color: - if value > 1.0: - raise exceptions.PlotlyError( - "Whoops! The elements in your colors tuples " - "cannot exceed 1.0." - ) - colors[j] = each_color - - if colortype == 'rgb': - for j, each_color in enumerate(colors): - rgb_color = FigureFactory._color_parser( - each_color, FigureFactory._convert_to_RGB_255 - ) - colors[j] = FigureFactory._color_parser( - rgb_color, FigureFactory._label_rgb - ) - - return colors - - @staticmethod - def _validate_colors_dict(colors, colortype='tuple'): - """ - Validates dictioanry of color(s) - """ - # validate each color element in the dictionary - for key in colors: - if 'rgb' in colors[key]: - colors[key] = FigureFactory._color_parser( - colors[key], FigureFactory._unlabel_rgb - ) - for value in colors[key]: - if value > 255.0: - raise exceptions.PlotlyError( - "Whoops! The elements in your rgb colors " - "tuples cannot exceed 255.0." - ) - colors[key] = FigureFactory._color_parser( - colors[key], FigureFactory._unconvert_from_RGB_255 - ) - - if '#' in colors[key]: - colors[key] = FigureFactory._color_parser( - colors[key], FigureFactory._hex_to_rgb - ) - colors[key] = FigureFactory._color_parser( - colors[key], FigureFactory._unconvert_from_RGB_255 - ) - - if isinstance(colors[key], tuple): - for value in colors[key]: - if value > 1.0: - raise exceptions.PlotlyError( - "Whoops! The elements in your colors tuples " - "cannot exceed 1.0." - ) - - if colortype == 'rgb': - for key in colors: - colors[key] = FigureFactory._color_parser( - colors[key], FigureFactory._convert_to_RGB_255 - ) - colors[key] = FigureFactory._color_parser( - colors[key], FigureFactory._label_rgb - ) - - return colors - - @staticmethod - def _calc_stats(data): - """ - Calculate statistics for use in violin plot. - """ - import numpy as np - - x = np.asarray(data, np.float) - vals_min = np.min(x) - vals_max = np.max(x) - q2 = np.percentile(x, 50, interpolation='linear') - q1 = np.percentile(x, 25, interpolation='lower') - q3 = np.percentile(x, 75, interpolation='higher') - iqr = q3 - q1 - whisker_dist = 1.5 * iqr - - # in order to prevent drawing whiskers outside the interval - # of data one defines the whisker positions as: - d1 = np.min(x[x >= (q1 - whisker_dist)]) - d2 = np.max(x[x <= (q3 + whisker_dist)]) - return { - 'min': vals_min, - 'max': vals_max, - 'q1': q1, - 'q2': q2, - 'q3': q3, - 'd1': d1, - 'd2': d2 - } - - @staticmethod - def _make_half_violin(x, y, fillcolor='#1f77b4', - linecolor='rgb(0, 0, 0)'): - """ - Produces a sideways probability distribution fig violin plot. - """ - from plotly.graph_objs import graph_objs - - text = ['(pdf(y), y)=(' + '{:0.2f}'.format(x[i]) + - ', ' + '{:0.2f}'.format(y[i]) + ')' - for i in range(len(x))] - - return graph_objs.Scatter( - x=x, - y=y, - mode='lines', - name='', - text=text, - fill='tonextx', - fillcolor=fillcolor, - line=graph_objs.Line(width=0.5, color=linecolor, shape='spline'), - hoverinfo='text', - opacity=0.5 - ) - - @staticmethod - def _make_violin_rugplot(vals, pdf_max, distance, - color='#1f77b4'): - """ - Returns a rugplot fig for a violin plot. - """ - from plotly.graph_objs import graph_objs - - return graph_objs.Scatter( - y=vals, - x=[-pdf_max-distance]*len(vals), - marker=graph_objs.Marker( - color=color, - symbol='line-ew-open' - ), - mode='markers', - name='', - showlegend=False, - hoverinfo='y' - ) - - @staticmethod - def _make_quartiles(q1, q3): - """ - Makes the upper and lower quartiles for a violin plot. - """ - from plotly.graph_objs import graph_objs - - return graph_objs.Scatter( - x=[0, 0], - y=[q1, q3], - text=['lower-quartile: ' + '{:0.2f}'.format(q1), - 'upper-quartile: ' + '{:0.2f}'.format(q3)], - mode='lines', - line=graph_objs.Line( - width=4, - color='rgb(0,0,0)' - ), - hoverinfo='text' - ) - - @staticmethod - def _make_median(q2): - """ - Formats the 'median' hovertext for a violin plot. - """ - from plotly.graph_objs import graph_objs - - return graph_objs.Scatter( - x=[0], - y=[q2], - text=['median: ' + '{:0.2f}'.format(q2)], - mode='markers', - marker=dict(symbol='square', - color='rgb(255,255,255)'), - hoverinfo='text' - ) @staticmethod - def _make_non_outlier_interval(d1, d2): - """ - Returns the scatterplot fig of most of a violin plot. - """ - from plotly.graph_objs import graph_objs - - return graph_objs.Scatter( - x=[0, 0], - y=[d1, d2], - name='', - mode='lines', - line=graph_objs.Line(width=1.5, - color='rgb(0,0,0)') + def _deprecated(old_method, new_method=None): + if new_method is None: + # The method name stayed the same. + new_method = old_method + warnings.warn( + 'plotly.tools.FigureFactory.{} is deprecated. ' + 'Use plotly.figure_factory.{}'.format(old_method, new_method) ) @staticmethod - def _make_XAxis(xaxis_title, xaxis_range): - """ - Makes the x-axis for a violin plot. - """ - from plotly.graph_objs import graph_objs - - xaxis = graph_objs.XAxis(title=xaxis_title, - range=xaxis_range, - showgrid=False, - zeroline=False, - showline=False, - mirror=False, - ticks='', - showticklabels=False, - ) - return xaxis - - @staticmethod - def _make_YAxis(yaxis_title): - """ - Makes the y-axis for a violin plot. - """ - from plotly.graph_objs import graph_objs - - yaxis = graph_objs.YAxis(title=yaxis_title, - showticklabels=True, - autorange=True, - ticklen=4, - showline=True, - zeroline=False, - showgrid=False, - mirror=False) - return yaxis - - @staticmethod - def _violinplot(vals, fillcolor='#1f77b4', rugplot=True): - """ - Refer to FigureFactory.create_violin() for docstring. - """ - import numpy as np - from scipy import stats - - vals = np.asarray(vals, np.float) - # summary statistics - vals_min = FigureFactory._calc_stats(vals)['min'] - vals_max = FigureFactory._calc_stats(vals)['max'] - q1 = FigureFactory._calc_stats(vals)['q1'] - q2 = FigureFactory._calc_stats(vals)['q2'] - q3 = FigureFactory._calc_stats(vals)['q3'] - d1 = FigureFactory._calc_stats(vals)['d1'] - d2 = FigureFactory._calc_stats(vals)['d2'] - - # kernel density estimation of pdf - pdf = stats.gaussian_kde(vals) - # grid over the data interval - xx = np.linspace(vals_min, vals_max, 100) - # evaluate the pdf at the grid xx - yy = pdf(xx) - max_pdf = np.max(yy) - # distance from the violin plot to rugplot - distance = (2.0 * max_pdf)/10 if rugplot else 0 - # range for x values in the plot - plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1] - plot_data = [FigureFactory._make_half_violin( - -yy, xx, fillcolor=fillcolor), - FigureFactory._make_half_violin( - yy, xx, fillcolor=fillcolor), - FigureFactory._make_non_outlier_interval(d1, d2), - FigureFactory._make_quartiles(q1, q3), - FigureFactory._make_median(q2)] - if rugplot: - plot_data.append(FigureFactory._make_violin_rugplot( - vals, - max_pdf, - distance=distance, - color=fillcolor) - ) - return plot_data, plot_xrange - - @staticmethod - def _violin_no_colorscale(data, data_header, group_header, colors, - use_colorscale, group_stats, - height, width, title): - """ - Refer to FigureFactory.create_violin() for docstring. - - Returns fig for violin plot without colorscale. - - """ - from plotly.graph_objs import graph_objs - import numpy as np - - # collect all group names - group_name = [] - for name in data[group_header]: - if name not in group_name: - group_name.append(name) - group_name.sort() - - gb = data.groupby([group_header]) - L = len(group_name) - - fig = make_subplots(rows=1, cols=L, - shared_yaxes=True, - horizontal_spacing=0.025, - print_grid=False) - color_index = 0 - for k, gr in enumerate(group_name): - vals = np.asarray(gb.get_group(gr)[data_header], np.float) - if color_index >= len(colors): - color_index = 0 - plot_data, plot_xrange = FigureFactory._violinplot( - vals, - fillcolor=colors[color_index] - ) - layout = graph_objs.Layout() - - for item in plot_data: - fig.append_trace(item, 1, k + 1) - color_index += 1 - - # add violin plot labels - fig['layout'].update({'xaxis{}'.format(k + 1): - FigureFactory._make_XAxis(group_name[k], - plot_xrange)}) - - # set the sharey axis style - fig['layout'].update( - {'yaxis{}'.format(1): FigureFactory._make_YAxis('')} - ) - fig['layout'].update( - title=title, - showlegend=False, - hovermode='closest', - autosize=False, - height=height, - width=width - ) - - return fig + def create_2D_density(*args, **kwargs): + FigureFactory._deprecated('create_2D_density', 'create_2d_density') + from plotly.figure_factory import create_2d_density + return create_2d_density(*args, **kwargs) @staticmethod - def _violin_colorscale(data, data_header, group_header, colors, - use_colorscale, group_stats, height, width, - title): - """ - Refer to FigureFactory.create_violin() for docstring. - - Returns fig for violin plot with colorscale. - - """ - from plotly.graph_objs import graph_objs - import numpy as np - - # collect all group names - group_name = [] - for name in data[group_header]: - if name not in group_name: - group_name.append(name) - group_name.sort() - - # make sure all group names are keys in group_stats - for group in group_name: - if group not in group_stats: - raise exceptions.PlotlyError("All values/groups in the index " - "column must be represented " - "as a key in group_stats.") - - gb = data.groupby([group_header]) - L = len(group_name) - - fig = make_subplots(rows=1, cols=L, - shared_yaxes=True, - horizontal_spacing=0.025, - print_grid=False) - - # prepare low and high color for colorscale - lowcolor = FigureFactory._color_parser( - colors[0], FigureFactory._unlabel_rgb - ) - highcolor = FigureFactory._color_parser( - colors[1], FigureFactory._unlabel_rgb - ) - - # find min and max values in group_stats - group_stats_values = [] - for key in group_stats: - group_stats_values.append(group_stats[key]) - - max_value = max(group_stats_values) - min_value = min(group_stats_values) - - for k, gr in enumerate(group_name): - vals = np.asarray(gb.get_group(gr)[data_header], np.float) - - # find intermediate color from colorscale - intermed = (group_stats[gr] - min_value) / (max_value - min_value) - intermed_color = FigureFactory._find_intermediate_color( - lowcolor, highcolor, intermed - ) - - plot_data, plot_xrange = FigureFactory._violinplot( - vals, - fillcolor='rgb{}'.format(intermed_color) - ) - layout = graph_objs.Layout() - - for item in plot_data: - fig.append_trace(item, 1, k + 1) - fig['layout'].update({'xaxis{}'.format(k + 1): - FigureFactory._make_XAxis(group_name[k], - plot_xrange)}) - # add colorbar to plot - trace_dummy = graph_objs.Scatter( - x=[0], - y=[0], - mode='markers', - marker=dict( - size=2, - cmin=min_value, - cmax=max_value, - colorscale=[[0, colors[0]], - [1, colors[1]]], - showscale=True), - showlegend=False, - ) - fig.append_trace(trace_dummy, 1, L) - - # set the sharey axis style - fig['layout'].update( - {'yaxis{}'.format(1): FigureFactory._make_YAxis('')} - ) - fig['layout'].update( - title=title, - showlegend=False, - hovermode='closest', - autosize=False, - height=height, - width=width - ) - - return fig + def create_annotated_heatmap(*args, **kwargs): + FigureFactory._deprecated('create_annotated_heatmap') + from plotly.figure_factory import create_annotated_heatmap + return create_annotated_heatmap(*args, **kwargs) @staticmethod - def _violin_dict(data, data_header, group_header, colors, use_colorscale, - group_stats, height, width, title): - """ - Refer to FigureFactory.create_violin() for docstring. - - Returns fig for violin plot without colorscale. - - """ - from plotly.graph_objs import graph_objs - import numpy as np - - # collect all group names - group_name = [] - for name in data[group_header]: - if name not in group_name: - group_name.append(name) - group_name.sort() - - # check if all group names appear in colors dict - for group in group_name: - if group not in colors: - raise exceptions.PlotlyError("If colors is a dictionary, all " - "the group names must appear as " - "keys in colors.") - - gb = data.groupby([group_header]) - L = len(group_name) - - fig = make_subplots(rows=1, cols=L, - shared_yaxes=True, - horizontal_spacing=0.025, - print_grid=False) - - for k, gr in enumerate(group_name): - vals = np.asarray(gb.get_group(gr)[data_header], np.float) - plot_data, plot_xrange = FigureFactory._violinplot( - vals, - fillcolor=colors[gr] - ) - layout = graph_objs.Layout() - - for item in plot_data: - fig.append_trace(item, 1, k + 1) - - # add violin plot labels - fig['layout'].update({'xaxis{}'.format(k + 1): - FigureFactory._make_XAxis(group_name[k], - plot_xrange)}) - - # set the sharey axis style - fig['layout'].update( - {'yaxis{}'.format(1): FigureFactory._make_YAxis('')} - ) - fig['layout'].update( - title=title, - showlegend=False, - hovermode='closest', - autosize=False, - height=height, - width=width - ) - - return fig + def create_candlestick(*args, **kwargs): + FigureFactory._deprecated('create_candlestick') + from plotly.figure_factory import create_candlestick + return create_candlestick(*args, **kwargs) @staticmethod - def create_violin(data, data_header=None, group_header=None, - colors=None, use_colorscale=False, group_stats=None, - height=450, width=600, title='Violin and Rug Plot'): - """ - Returns figure for a violin plot - - :param (list|array) data: accepts either a list of numerical values, - a list of dictionaries all with identical keys and at least one - column of numeric values, or a pandas dataframe with at least one - column of numbers - :param (str) data_header: the header of the data column to be used - from an inputted pandas dataframe. Not applicable if 'data' is - a list of numeric values - :param (str) group_header: applicable if grouping data by a variable. - 'group_header' must be set to the name of the grouping variable. - :param (str|tuple|list|dict) colors: either a plotly scale name, - an rgb or hex color, a color tuple, a list of colors or a - dictionary. An rgb color is of the form 'rgb(x, y, z)' where - x, y and z belong to the interval [0, 255] and a color tuple is a - tuple of the form (a, b, c) where a, b and c belong to [0, 1]. - If colors is a list, it must contain valid color types as its - members. - :param (bool) use_colorscale: Only applicable if grouping by another - variable. Will implement a colorscale based on the first 2 colors - of param colors. This means colors must be a list with at least 2 - colors in it (Plotly colorscales are accepted since they map to a - list of two rgb colors) - :param (dict) group_stats: a dictioanry where each key is a unique - value from the group_header column in data. Each value must be a - number and will be used to color the violin plots if a colorscale - is being used - :param (float) height: the height of the violin plot - :param (float) width: the width of the violin plot - :param (str) title: the title of the violin plot - - Example 1: Single Violin Plot - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import graph_objs - - import numpy as np - from scipy import stats - - # create list of random values - data_list = np.random.randn(100) - data_list.tolist() - - # create violin fig - fig = FF.create_violin(data_list, colors='#604d9e') - - # plot - py.iplot(fig, filename='Violin Plot') - ``` - - Example 2: Multiple Violin Plots with Qualitative Coloring - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import graph_objs - - import numpy as np - import pandas as pd - from scipy import stats - - # create dataframe - np.random.seed(619517) - Nr=250 - y = np.random.randn(Nr) - gr = np.random.choice(list("ABCDE"), Nr) - norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)] - - for i, letter in enumerate("ABCDE"): - y[gr == letter] *=norm_params[i][1]+ norm_params[i][0] - df = pd.DataFrame(dict(Score=y, Group=gr)) - - # create violin fig - fig = FF.create_violin(df, data_header='Score', group_header='Group', - height=600, width=1000) - - # plot - py.iplot(fig, filename='Violin Plot with Coloring') - ``` - - Example 3: Violin Plots with Colorscale - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import graph_objs - - import numpy as np - import pandas as pd - from scipy import stats - - # create dataframe - np.random.seed(619517) - Nr=250 - y = np.random.randn(Nr) - gr = np.random.choice(list("ABCDE"), Nr) - norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)] - - for i, letter in enumerate("ABCDE"): - y[gr == letter] *=norm_params[i][1]+ norm_params[i][0] - df = pd.DataFrame(dict(Score=y, Group=gr)) - - # define header params - data_header = 'Score' - group_header = 'Group' - - # make groupby object with pandas - group_stats = {} - groupby_data = df.groupby([group_header]) - - for group in "ABCDE": - data_from_group = groupby_data.get_group(group)[data_header] - # take a stat of the grouped data - stat = np.median(data_from_group) - # add to dictionary - group_stats[group] = stat - - # create violin fig - fig = FF.create_violin(df, data_header='Score', group_header='Group', - height=600, width=1000, use_colorscale=True, - group_stats=group_stats) - - # plot - py.iplot(fig, filename='Violin Plot with Colorscale') - ``` - """ - from plotly.graph_objs import graph_objs - from numbers import Number - - # Validate colors - if isinstance(colors, dict): - valid_colors = FigureFactory._validate_colors_dict(colors, 'rgb') - else: - valid_colors = FigureFactory._validate_colors(colors, 'rgb') - - # validate data and choose plot type - if group_header is None: - if isinstance(data, list): - if len(data) <= 0: - raise exceptions.PlotlyError("If data is a list, it must be " - "nonempty and contain either " - "numbers or dictionaries.") - - if not all(isinstance(element, Number) for element in data): - raise exceptions.PlotlyError("If data is a list, it must " - "contain only numbers.") - - if _pandas_imported and isinstance(data, pd.core.frame.DataFrame): - if data_header is None: - raise exceptions.PlotlyError("data_header must be the " - "column name with the " - "desired numeric data for " - "the violin plot.") - - data = data[data_header].values.tolist() - - # call the plotting functions - plot_data, plot_xrange = FigureFactory._violinplot( - data, fillcolor=valid_colors[0] - ) - - layout = graph_objs.Layout( - title=title, - autosize=False, - font=graph_objs.Font(size=11), - height=height, - showlegend=False, - width=width, - xaxis=FigureFactory._make_XAxis('', plot_xrange), - yaxis=FigureFactory._make_YAxis(''), - hovermode='closest' - ) - layout['yaxis'].update(dict(showline=False, - showticklabels=False, - ticks='')) - - fig = graph_objs.Figure(data=graph_objs.Data(plot_data), - layout=layout) - - return fig - - else: - if not isinstance(data, pd.core.frame.DataFrame): - raise exceptions.PlotlyError("Error. You must use a pandas " - "DataFrame if you are using a " - "group header.") - - if data_header is None: - raise exceptions.PlotlyError("data_header must be the column " - "name with the desired numeric " - "data for the violin plot.") - - if use_colorscale is False: - if isinstance(valid_colors, dict): - # validate colors dict choice below - fig = FigureFactory._violin_dict( - data, data_header, group_header, valid_colors, - use_colorscale, group_stats, height, width, title - ) - return fig - else: - fig = FigureFactory._violin_no_colorscale( - data, data_header, group_header, valid_colors, - use_colorscale, group_stats, height, width, title - ) - return fig - else: - if isinstance(valid_colors, dict): - raise exceptions.PlotlyError("The colors param cannot be " - "a dictionary if you are " - "using a colorscale.") - - if len(valid_colors) < 2: - raise exceptions.PlotlyError("colors must be a list with " - "at least 2 colors. A " - "Plotly scale is allowed.") - - if not isinstance(group_stats, dict): - raise exceptions.PlotlyError("Your group_stats param " - "must be a dictionary.") - - fig = FigureFactory._violin_colorscale( - data, data_header, group_header, valid_colors, - use_colorscale, group_stats, height, width, title - ) - return fig + def create_dendrogram(*args, **kwargs): + FigureFactory._deprecated('create_dendrogram') + from plotly.figure_factory import create_dendrogram + return create_dendrogram(*args, **kwargs) @staticmethod - def _find_intermediate_color(lowcolor, highcolor, intermed): - """ - Returns the color at a given distance between two colors - - This function takes two color tuples, where each element is between 0 - and 1, along with a value 0 < intermed < 1 and returns a color that is - intermed-percent from lowcolor to highcolor - - """ - diff_0 = float(highcolor[0] - lowcolor[0]) - diff_1 = float(highcolor[1] - lowcolor[1]) - diff_2 = float(highcolor[2] - lowcolor[2]) - - return (lowcolor[0] + intermed * diff_0, - lowcolor[1] + intermed * diff_1, - lowcolor[2] + intermed * diff_2) + def create_distplot(*args, **kwargs): + FigureFactory._deprecated('create_distplot') + from plotly.figure_factory import create_distplot + return create_distplot(*args, **kwargs) @staticmethod - def _color_parser(colors, function): - """ - Takes color(s) and a function and applies the function on the color(s) - - In particular, this function identifies whether the given color object - is an iterable or not and applies the given color-parsing function to - the color or iterable of colors. If given an iterable, it will only be - able to work with it if all items in the iterable are of the same type - - rgb string, hex string or tuple - - """ - from numbers import Number - if isinstance(colors, str): - return function(colors) - - if isinstance(colors, tuple) and isinstance(colors[0], Number): - return function(colors) - - if hasattr(colors, '__iter__'): - if isinstance(colors, tuple): - new_color_tuple = tuple(function(item) for item in colors) - return new_color_tuple - - else: - new_color_list = [function(item) for item in colors] - return new_color_list + def create_gantt(*args, **kwargs): + FigureFactory._deprecated('create_gantt') + from plotly.figure_factory import create_gantt + return create_gantt(*args, **kwargs) @staticmethod - def _unconvert_from_RGB_255(colors): - """ - Return a tuple where each element gets divided by 255 - - Takes a (list of) color tuple(s) where each element is between 0 and - 255. Returns the same tuples where each tuple element is normalized to - a value between 0 and 1 - - """ - return (colors[0]/(255.0), - colors[1]/(255.0), - colors[2]/(255.0)) + def create_ohlc(*args, **kwargs): + FigureFactory._deprecated('create_ohlc') + from plotly.figure_factory import create_ohlc + return create_ohlc(*args, **kwargs) @staticmethod - def _map_face2color(face, colormap, scale, vmin, vmax): - """ - Normalize facecolor values by vmin/vmax and return rgb-color strings - - This function takes a tuple color along with a colormap and a minimum - (vmin) and maximum (vmax) range of possible mean distances for the - given parametrized surface. It returns an rgb color based on the mean - distance between vmin and vmax - - """ - if vmin >= vmax: - raise exceptions.PlotlyError("Incorrect relation between vmin " - "and vmax. The vmin value cannot be " - "bigger than or equal to the value " - "of vmax.") - if len(colormap) == 1: - # color each triangle face with the same color in colormap - face_color = colormap[0] - face_color = colors.convert_to_RGB_255(face_color) - face_color = colors.label_rgb(face_color) - return face_color - if face == vmax: - # pick last color in colormap - face_color = colormap[-1] - face_color = colors.convert_to_RGB_255(face_color) - face_color = colors.label_rgb(face_color) - return face_color - else: - if scale is None: - # find the normalized distance t of a triangle face between - # vmin and vmax where the distance is between 0 and 1 - t = (face - vmin) / float((vmax - vmin)) - low_color_index = int(t / (1./(len(colormap) - 1))) - - face_color = colors.find_intermediate_color( - colormap[low_color_index], - colormap[low_color_index + 1], - t * (len(colormap) - 1) - low_color_index - ) - - face_color = colors.convert_to_RGB_255(face_color) - face_color = colors.label_rgb(face_color) - else: - # find the face color for a non-linearly interpolated scale - t = (face - vmin) / float((vmax - vmin)) - - low_color_index = 0 - for k in range(len(scale) - 1): - if scale[k] <= t < scale[k+1]: - break - low_color_index += 1 - - low_scale_val = scale[low_color_index] - high_scale_val = scale[low_color_index + 1] - - face_color = colors.find_intermediate_color( - colormap[low_color_index], - colormap[low_color_index + 1], - (t - low_scale_val)/(high_scale_val - low_scale_val) - ) - - face_color = colors.convert_to_RGB_255(face_color) - face_color = colors.label_rgb(face_color) - return face_color + def create_quiver(*args, **kwargs): + FigureFactory._deprecated('create_quiver') + from plotly.figure_factory import create_quiver + return create_quiver(*args, **kwargs) @staticmethod - def _trisurf(x, y, z, simplices, show_colorbar, edges_color, scale, - colormap=None, color_func=None, plot_edges=False, - x_edge=None, y_edge=None, z_edge=None, facecolor=None): - """ - Refer to FigureFactory.create_trisurf() for docstring - """ - # numpy import check - if _numpy_imported is False: - raise ImportError("FigureFactory._trisurf() requires " - "numpy imported.") - import numpy as np - from plotly.graph_objs import graph_objs - points3D = np.vstack((x, y, z)).T - simplices = np.atleast_2d(simplices) - - # vertices of the surface triangles - tri_vertices = points3D[simplices] - - # Define colors for the triangle faces - if color_func is None: - # mean values of z-coordinates of triangle vertices - mean_dists = tri_vertices[:, :, 2].mean(-1) - elif isinstance(color_func, (list, np.ndarray)): - # Pre-computed list / array of values to map onto color - if len(color_func) != len(simplices): - raise ValueError("If color_func is a list/array, it must " - "be the same length as simplices.") - - # convert all colors in color_func to rgb - for index in range(len(color_func)): - if isinstance(color_func[index], str): - if '#' in color_func[index]: - foo = colors.hex_to_rgb(color_func[index]) - color_func[index] = colors.label_rgb(foo) - - if isinstance(color_func[index], tuple): - foo = colors.convert_to_RGB_255(color_func[index]) - color_func[index] = colors.label_rgb(foo) - - mean_dists = np.asarray(color_func) - else: - # apply user inputted function to calculate - # custom coloring for triangle vertices - mean_dists = [] - for triangle in tri_vertices: - dists = [] - for vertex in triangle: - dist = color_func(vertex[0], vertex[1], vertex[2]) - dists.append(dist) - mean_dists.append(np.mean(dists)) - mean_dists = np.asarray(mean_dists) - - # Check if facecolors are already strings and can be skipped - if isinstance(mean_dists[0], str): - facecolor = mean_dists - else: - min_mean_dists = np.min(mean_dists) - max_mean_dists = np.max(mean_dists) - - if facecolor is None: - facecolor = [] - for index in range(len(mean_dists)): - color = FigureFactory._map_face2color(mean_dists[index], - colormap, - scale, - min_mean_dists, - max_mean_dists) - facecolor.append(color) - - # Make sure facecolor is a list so output is consistent across Pythons - facecolor = np.asarray(facecolor) - ii, jj, kk = simplices.T - - triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor, - i=ii, j=jj, k=kk, name='') - - mean_dists_are_numbers = not isinstance(mean_dists[0], str) - - if mean_dists_are_numbers and show_colorbar is True: - # make a colorscale from the colors - colorscale = colors.make_colorscale(colormap, scale) - colorscale = colors.convert_colorscale_to_rgb(colorscale) - - colorbar = graph_objs.Scatter3d( - x=x[:1], - y=y[:1], - z=z[:1], - mode='markers', - marker=dict( - size=0.1, - color=[min_mean_dists, max_mean_dists], - colorscale=colorscale, - showscale=True), - hoverinfo='None', - showlegend=False - ) - - # the triangle sides are not plotted - if plot_edges is False: - if mean_dists_are_numbers and show_colorbar is True: - return graph_objs.Data([triangles, colorbar]) - else: - return graph_objs.Data([triangles]) - - # define the lists x_edge, y_edge and z_edge, of x, y, resp z - # coordinates of edge end points for each triangle - # None separates data corresponding to two consecutive triangles - is_none = [ii is None for ii in [x_edge, y_edge, z_edge]] - if any(is_none): - if not all(is_none): - raise ValueError("If any (x_edge, y_edge, z_edge) is None, " - "all must be None") - else: - x_edge = [] - y_edge = [] - z_edge = [] - - # Pull indices we care about, then add a None column to separate tris - ixs_triangles = [0, 1, 2, 0] - pull_edges = tri_vertices[:, ixs_triangles, :] - x_edge_pull = np.hstack([pull_edges[:, :, 0], - np.tile(None, [pull_edges.shape[0], 1])]) - y_edge_pull = np.hstack([pull_edges[:, :, 1], - np.tile(None, [pull_edges.shape[0], 1])]) - z_edge_pull = np.hstack([pull_edges[:, :, 2], - np.tile(None, [pull_edges.shape[0], 1])]) - - # Now unravel the edges into a 1-d vector for plotting - x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]]) - y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]]) - z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]]) - - if not (len(x_edge) == len(y_edge) == len(z_edge)): - raise exceptions.PlotlyError("The lengths of x_edge, y_edge and " - "z_edge are not the same.") - - # define the lines for plotting - lines = graph_objs.Scatter3d( - x=x_edge, y=y_edge, z=z_edge, mode='lines', - line=graph_objs.Line( - color=edges_color, - width=1.5 - ), - showlegend=False - ) - - if mean_dists_are_numbers and show_colorbar is True: - return graph_objs.Data([triangles, lines, colorbar]) - else: - return graph_objs.Data([triangles, lines]) + def create_scatterplotmatrix(*args, **kwargs): + FigureFactory._deprecated('create_scatterplotmatrix') + from plotly.figure_factory import create_scatterplotmatrix + return create_scatterplotmatrix(*args, **kwargs) @staticmethod - def create_trisurf(x, y, z, simplices, colormap=None, show_colorbar=True, - scale=None, color_func=None, title='Trisurf Plot', - plot_edges=True, showbackground=True, - backgroundcolor='rgb(230, 230, 230)', - gridcolor='rgb(255, 255, 255)', - zerolinecolor='rgb(255, 255, 255)', - edges_color='rgb(50, 50, 50)', - height=800, width=800, - aspectratio=dict(x=1, y=1, z=1)): - """ - Returns figure for a triangulated surface plot - - :param (array) x: data values of x in a 1D array - :param (array) y: data values of y in a 1D array - :param (array) z: data values of z in a 1D array - :param (array) simplices: an array of shape (ntri, 3) where ntri is - the number of triangles in the triangularization. Each row of the - array contains the indicies of the verticies of each triangle - :param (str|tuple|list) colormap: either a plotly scale name, an rgb - or hex color, a color tuple or a list of colors. An rgb color is - of the form 'rgb(x, y, z)' where x, y, z belong to the interval - [0, 255] and a color tuple is a tuple of the form (a, b, c) where - a, b and c belong to [0, 1]. If colormap is a list, it must - contain the valid color types aforementioned as its members - :param (bool) show_colorbar: determines if colorbar is visible - :param (list|array) scale: sets the scale values to be used if a non- - linearly interpolated colormap is desired. If left as None, a - linear interpolation between the colors will be excecuted - :param (function|list) color_func: The parameter that determines the - coloring of the surface. Takes either a function with 3 arguments - x, y, z or a list/array of color values the same length as - simplices. If None, coloring will only depend on the z axis - :param (str) title: title of the plot - :param (bool) plot_edges: determines if the triangles on the trisurf - are visible - :param (bool) showbackground: makes background in plot visible - :param (str) backgroundcolor: color of background. Takes a string of - the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive - :param (str) gridcolor: color of the gridlines besides the axes. Takes - a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255 - inclusive - :param (str) zerolinecolor: color of the axes. Takes a string of the - form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive - :param (str) edges_color: color of the edges, if plot_edges is True - :param (int|float) height: the height of the plot (in pixels) - :param (int|float) width: the width of the plot (in pixels) - :param (dict) aspectratio: a dictionary of the aspect ratio values for - the x, y and z axes. 'x', 'y' and 'z' take (int|float) values - - Example 1: Sphere - ``` - # Necessary Imports for Trisurf - import numpy as np - from scipy.spatial import Delaunay - - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import graph_objs - - # Make data for plot - u = np.linspace(0, 2*np.pi, 20) - v = np.linspace(0, np.pi, 20) - u,v = np.meshgrid(u,v) - u = u.flatten() - v = v.flatten() - - x = np.sin(v)*np.cos(u) - y = np.sin(v)*np.sin(u) - z = np.cos(v) - - points2D = np.vstack([u,v]).T - tri = Delaunay(points2D) - simplices = tri.simplices - - # Create a figure - fig1 = FF.create_trisurf(x=x, y=y, z=z, - colormap="Rainbow", - simplices=simplices) - # Plot the data - py.iplot(fig1, filename='trisurf-plot-sphere') - ``` - - Example 2: Torus - ``` - # Necessary Imports for Trisurf - import numpy as np - from scipy.spatial import Delaunay - - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import graph_objs - - # Make data for plot - u = np.linspace(0, 2*np.pi, 20) - v = np.linspace(0, 2*np.pi, 20) - u,v = np.meshgrid(u,v) - u = u.flatten() - v = v.flatten() - - x = (3 + (np.cos(v)))*np.cos(u) - y = (3 + (np.cos(v)))*np.sin(u) - z = np.sin(v) - - points2D = np.vstack([u,v]).T - tri = Delaunay(points2D) - simplices = tri.simplices - - # Create a figure - fig1 = FF.create_trisurf(x=x, y=y, z=z, - colormap="Viridis", - simplices=simplices) - # Plot the data - py.iplot(fig1, filename='trisurf-plot-torus') - ``` - - Example 3: Mobius Band - ``` - # Necessary Imports for Trisurf - import numpy as np - from scipy.spatial import Delaunay - - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import graph_objs - - # Make data for plot - u = np.linspace(0, 2*np.pi, 24) - v = np.linspace(-1, 1, 8) - u,v = np.meshgrid(u,v) - u = u.flatten() - v = v.flatten() - - tp = 1 + 0.5*v*np.cos(u/2.) - x = tp*np.cos(u) - y = tp*np.sin(u) - z = 0.5*v*np.sin(u/2.) - - points2D = np.vstack([u,v]).T - tri = Delaunay(points2D) - simplices = tri.simplices - - # Create a figure - fig1 = FF.create_trisurf(x=x, y=y, z=z, - colormap=[(0.2, 0.4, 0.6), (1, 1, 1)], - simplices=simplices) - # Plot the data - py.iplot(fig1, filename='trisurf-plot-mobius-band') - ``` - - Example 4: Using a Custom Colormap Function with Light Cone - ``` - # Necessary Imports for Trisurf - import numpy as np - from scipy.spatial import Delaunay - - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import graph_objs - - # Make data for plot - u=np.linspace(-np.pi, np.pi, 30) - v=np.linspace(-np.pi, np.pi, 30) - u,v=np.meshgrid(u,v) - u=u.flatten() - v=v.flatten() - - x = u - y = u*np.cos(v) - z = u*np.sin(v) - - points2D = np.vstack([u,v]).T - tri = Delaunay(points2D) - simplices = tri.simplices - - # Define distance function - def dist_origin(x, y, z): - return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2) - - # Create a figure - fig1 = FF.create_trisurf(x=x, y=y, z=z, - colormap=['#FFFFFF', '#E4FFFE', - '#A4F6F9', '#FF99FE', - '#BA52ED'], - scale=[0, 0.6, 0.71, 0.89, 1], - simplices=simplices, - color_func=dist_origin) - # Plot the data - py.iplot(fig1, filename='trisurf-plot-custom-coloring') - ``` - - Example 5: Enter color_func as a list of colors - ``` - # Necessary Imports for Trisurf - import numpy as np - from scipy.spatial import Delaunay - import random - - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import graph_objs - - # Make data for plot - u=np.linspace(-np.pi, np.pi, 30) - v=np.linspace(-np.pi, np.pi, 30) - u,v=np.meshgrid(u,v) - u=u.flatten() - v=v.flatten() - - x = u - y = u*np.cos(v) - z = u*np.sin(v) - - points2D = np.vstack([u,v]).T - tri = Delaunay(points2D) - simplices = tri.simplices - - - colors = [] - color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd'] - - for index in range(len(simplices)): - colors.append(random.choice(color_choices)) - - fig = FF.create_trisurf( - x, y, z, simplices, - color_func=colors, - show_colorbar=True, - edges_color='rgb(2, 85, 180)', - title=' Modern Art' - ) - - py.iplot(fig, filename="trisurf-plot-modern-art") - ``` - """ - from plotly.graph_objs import graph_objs - - # Validate colormap - colors.validate_colors(colormap) - colormap, scale = colors.convert_colors_to_same_type( - colormap, colortype='tuple', - return_default_colors=True, scale=scale - ) - - data1 = FigureFactory._trisurf(x, y, z, simplices, - show_colorbar=show_colorbar, - color_func=color_func, - colormap=colormap, - scale=scale, - edges_color=edges_color, - plot_edges=plot_edges) - - axis = dict( - showbackground=showbackground, - backgroundcolor=backgroundcolor, - gridcolor=gridcolor, - zerolinecolor=zerolinecolor, - ) - layout = graph_objs.Layout( - title=title, - width=width, - height=height, - scene=graph_objs.Scene( - xaxis=graph_objs.XAxis(axis), - yaxis=graph_objs.YAxis(axis), - zaxis=graph_objs.ZAxis(axis), - aspectratio=dict( - x=aspectratio['x'], - y=aspectratio['y'], - z=aspectratio['z']), - ) - ) - - return graph_objs.Figure(data=data1, layout=layout) + def create_streamline(*args, **kwargs): + FigureFactory._deprecated('create_streamline') + from plotly.figure_factory import create_streamline + return create_streamline(*args, **kwargs) @staticmethod - def _scatterplot(dataframe, headers, diag, size, - height, width, title, **kwargs): - """ - Refer to FigureFactory.create_scatterplotmatrix() for docstring - - Returns fig for scatterplotmatrix without index - - """ - from plotly.graph_objs import graph_objs - dim = len(dataframe) - fig = make_subplots(rows=dim, cols=dim, print_grid=False) - trace_list = [] - # Insert traces into trace_list - for listy in dataframe: - for listx in dataframe: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=listx, - showlegend=False - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=listx, - name=None, - showlegend=False - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - trace = graph_objs.Scatter( - x=listx, - y=listy, - mode='markers', - showlegend=False, - **kwargs - ) - trace_list.append(trace) - else: - trace = graph_objs.Scatter( - x=listx, - y=listy, - mode='markers', - marker=dict( - size=size), - showlegend=False, - **kwargs - ) - trace_list.append(trace) - - trace_index = 0 - indices = range(1, dim + 1) - for y_index in indices: - for x_index in indices: - fig.append_trace(trace_list[trace_index], - y_index, - x_index) - trace_index += 1 - - # Insert headers into the figure - for j in range(dim): - xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) - fig['layout'][xaxis_key].update(title=headers[j]) - for j in range(dim): - yaxis_key = 'yaxis{}'.format(1 + (dim * j)) - fig['layout'][yaxis_key].update(title=headers[j]) - - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True - ) - - FigureFactory._hide_tick_labels_from_box_subplots(fig) - - return fig + def create_table(*args, **kwargs): + FigureFactory._deprecated('create_table') + from plotly.figure_factory import create_table + return create_table(*args, **kwargs) @staticmethod - def _scatterplot_dict(dataframe, headers, diag, size, - height, width, title, index, index_vals, - endpts, colormap, colormap_type, **kwargs): - """ - Refer to FigureFactory.create_scatterplotmatrix() for docstring - - Returns fig for scatterplotmatrix with both index and colormap picked. - Used if colormap is a dictionary with index values as keys pointing to - colors. Forces colormap_type to behave categorically because it would - not make sense colors are assigned to each index value and thus - implies that a categorical approach should be taken - - """ - from plotly.graph_objs import graph_objs - - theme = colormap - dim = len(dataframe) - fig = make_subplots(rows=dim, cols=dim, print_grid=False) - trace_list = [] - legend_param = 0 - # Work over all permutations of list pairs - for listy in dataframe: - for listx in dataframe: - # create a dictionary for index_vals - unique_index_vals = {} - for name in index_vals: - if name not in unique_index_vals: - unique_index_vals[name] = [] - - # Fill all the rest of the names into the dictionary - for name in sorted(unique_index_vals.keys()): - new_listx = [] - new_listy = [] - for j in range(len(index_vals)): - if index_vals[j] == name: - new_listx.append(listx[j]) - new_listy.append(listy[j]) - # Generate trace with VISIBLE icon - if legend_param == 1: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=new_listx, - marker=dict( - color=theme[name]), - showlegend=True - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=new_listx, - name=None, - marker=dict( - color=theme[name]), - showlegend=True - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - kwargs['marker']['color'] = theme[name] - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=name, - showlegend=True, - **kwargs - ) - else: - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=name, - marker=dict( - size=size, - color=theme[name]), - showlegend=True, - **kwargs - ) - # Generate trace with INVISIBLE icon - else: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=new_listx, - marker=dict( - color=theme[name]), - showlegend=False - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=new_listx, - name=None, - marker=dict( - color=theme[name]), - showlegend=False - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - kwargs['marker']['color'] = theme[name] - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=name, - showlegend=False, - **kwargs - ) - else: - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=name, - marker=dict( - size=size, - color=theme[name]), - showlegend=False, - **kwargs - ) - # Push the trace into dictionary - unique_index_vals[name] = trace - trace_list.append(unique_index_vals) - legend_param += 1 - - trace_index = 0 - indices = range(1, dim + 1) - for y_index in indices: - for x_index in indices: - for name in sorted(trace_list[trace_index].keys()): - fig.append_trace( - trace_list[trace_index][name], - y_index, - x_index) - trace_index += 1 - - # Insert headers into the figure - for j in range(dim): - xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) - fig['layout'][xaxis_key].update(title=headers[j]) - - for j in range(dim): - yaxis_key = 'yaxis{}'.format(1 + (dim * j)) - fig['layout'][yaxis_key].update(title=headers[j]) - - FigureFactory._hide_tick_labels_from_box_subplots(fig) - - if diag == 'histogram': - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True, - barmode='stack') - return fig - - else: - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True) - return fig + def create_trisurf(*args, **kwargs): + FigureFactory._deprecated('create_trisurf') + from plotly.figure_factory import create_trisurf + return create_trisurf(*args, **kwargs) @staticmethod - def _scatterplot_theme(dataframe, headers, diag, size, height, - width, title, index, index_vals, endpts, - colormap, colormap_type, **kwargs): - """ - Refer to FigureFactory.create_scatterplotmatrix() for docstring - - Returns fig for scatterplotmatrix with both index and colormap picked - - """ - from plotly.graph_objs import graph_objs - - # Check if index is made of string values - if isinstance(index_vals[0], str): - unique_index_vals = [] - for name in index_vals: - if name not in unique_index_vals: - unique_index_vals.append(name) - n_colors_len = len(unique_index_vals) - - # Convert colormap to list of n RGB tuples - if colormap_type == 'seq': - foo = FigureFactory._color_parser( - colormap, FigureFactory._unlabel_rgb - ) - foo = FigureFactory._n_colors(foo[0], - foo[1], - n_colors_len) - theme = FigureFactory._color_parser( - foo, FigureFactory._label_rgb - ) - - if colormap_type == 'cat': - # leave list of colors the same way - theme = colormap - - dim = len(dataframe) - fig = make_subplots(rows=dim, cols=dim, print_grid=False) - trace_list = [] - legend_param = 0 - # Work over all permutations of list pairs - for listy in dataframe: - for listx in dataframe: - # create a dictionary for index_vals - unique_index_vals = {} - for name in index_vals: - if name not in unique_index_vals: - unique_index_vals[name] = [] - - c_indx = 0 # color index - # Fill all the rest of the names into the dictionary - for name in sorted(unique_index_vals.keys()): - new_listx = [] - new_listy = [] - for j in range(len(index_vals)): - if index_vals[j] == name: - new_listx.append(listx[j]) - new_listy.append(listy[j]) - # Generate trace with VISIBLE icon - if legend_param == 1: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=new_listx, - marker=dict( - color=theme[c_indx]), - showlegend=True - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=new_listx, - name=None, - marker=dict( - color=theme[c_indx]), - showlegend=True - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - kwargs['marker']['color'] = theme[c_indx] - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=name, - showlegend=True, - **kwargs - ) - else: - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=name, - marker=dict( - size=size, - color=theme[c_indx]), - showlegend=True, - **kwargs - ) - # Generate trace with INVISIBLE icon - else: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=new_listx, - marker=dict( - color=theme[c_indx]), - showlegend=False - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=new_listx, - name=None, - marker=dict( - color=theme[c_indx]), - showlegend=False - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - kwargs['marker']['color'] = theme[c_indx] - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=name, - showlegend=False, - **kwargs - ) - else: - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=name, - marker=dict( - size=size, - color=theme[c_indx]), - showlegend=False, - **kwargs - ) - # Push the trace into dictionary - unique_index_vals[name] = trace - if c_indx >= (len(theme) - 1): - c_indx = -1 - c_indx += 1 - trace_list.append(unique_index_vals) - legend_param += 1 - - trace_index = 0 - indices = range(1, dim + 1) - for y_index in indices: - for x_index in indices: - for name in sorted(trace_list[trace_index].keys()): - fig.append_trace( - trace_list[trace_index][name], - y_index, - x_index) - trace_index += 1 - - # Insert headers into the figure - for j in range(dim): - xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) - fig['layout'][xaxis_key].update(title=headers[j]) - - for j in range(dim): - yaxis_key = 'yaxis{}'.format(1 + (dim * j)) - fig['layout'][yaxis_key].update(title=headers[j]) - - FigureFactory._hide_tick_labels_from_box_subplots(fig) - - if diag == 'histogram': - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True, - barmode='stack') - return fig - - elif diag == 'box': - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True) - return fig - - else: - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True) - return fig - - else: - if endpts: - intervals = FigureFactory._endpts_to_intervals(endpts) - - # Convert colormap to list of n RGB tuples - if colormap_type == 'seq': - foo = FigureFactory._color_parser( - colormap, FigureFactory._unlabel_rgb - ) - foo = FigureFactory._n_colors(foo[0], - foo[1], - len(intervals)) - theme = FigureFactory._color_parser( - foo, FigureFactory._label_rgb - ) - - if colormap_type == 'cat': - # leave list of colors the same way - theme = colormap - - dim = len(dataframe) - fig = make_subplots(rows=dim, cols=dim, print_grid=False) - trace_list = [] - legend_param = 0 - # Work over all permutations of list pairs - for listy in dataframe: - for listx in dataframe: - interval_labels = {} - for interval in intervals: - interval_labels[str(interval)] = [] - - c_indx = 0 # color index - # Fill all the rest of the names into the dictionary - for interval in intervals: - new_listx = [] - new_listy = [] - for j in range(len(index_vals)): - if interval[0] < index_vals[j] <= interval[1]: - new_listx.append(listx[j]) - new_listy.append(listy[j]) - # Generate trace with VISIBLE icon - if legend_param == 1: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=new_listx, - marker=dict( - color=theme[c_indx]), - showlegend=True - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=new_listx, - name=None, - marker=dict( - color=theme[c_indx]), - showlegend=True - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - (kwargs['marker'] - ['color']) = theme[c_indx] - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=str(interval), - showlegend=True, - **kwargs - ) - else: - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=str(interval), - marker=dict( - size=size, - color=theme[c_indx]), - showlegend=True, - **kwargs - ) - # Generate trace with INVISIBLE icon - else: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=new_listx, - marker=dict( - color=theme[c_indx]), - showlegend=False - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=new_listx, - name=None, - marker=dict( - color=theme[c_indx]), - showlegend=False - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - (kwargs['marker'] - ['color']) = theme[c_indx] - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=str(interval), - showlegend=False, - **kwargs - ) - else: - trace = graph_objs.Scatter( - x=new_listx, - y=new_listy, - mode='markers', - name=str(interval), - marker=dict( - size=size, - color=theme[c_indx]), - showlegend=False, - **kwargs - ) - # Push the trace into dictionary - interval_labels[str(interval)] = trace - if c_indx >= (len(theme) - 1): - c_indx = -1 - c_indx += 1 - trace_list.append(interval_labels) - legend_param += 1 - - trace_index = 0 - indices = range(1, dim + 1) - for y_index in indices: - for x_index in indices: - for interval in intervals: - fig.append_trace( - trace_list[trace_index][str(interval)], - y_index, - x_index) - trace_index += 1 - - # Insert headers into the figure - for j in range(dim): - xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) - fig['layout'][xaxis_key].update(title=headers[j]) - for j in range(dim): - yaxis_key = 'yaxis{}'.format(1 + (dim * j)) - fig['layout'][yaxis_key].update(title=headers[j]) - - FigureFactory._hide_tick_labels_from_box_subplots(fig) - - if diag == 'histogram': - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True, - barmode='stack') - return fig - - elif diag == 'box': - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True) - return fig - - else: - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True) - return fig - - else: - theme = colormap - - # add a copy of rgb color to theme if it contains one color - if len(theme) <= 1: - theme.append(theme[0]) - - color = [] - for incr in range(len(theme)): - color.append([1./(len(theme)-1)*incr, theme[incr]]) - - dim = len(dataframe) - fig = make_subplots(rows=dim, cols=dim, print_grid=False) - trace_list = [] - legend_param = 0 - # Run through all permutations of list pairs - for listy in dataframe: - for listx in dataframe: - # Generate trace with VISIBLE icon - if legend_param == 1: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=listx, - marker=dict( - color=theme[0]), - showlegend=False - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=listx, - marker=dict( - color=theme[0]), - showlegend=False - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - kwargs['marker']['color'] = index_vals - kwargs['marker']['colorscale'] = color - kwargs['marker']['showscale'] = True - trace = graph_objs.Scatter( - x=listx, - y=listy, - mode='markers', - showlegend=False, - **kwargs - ) - else: - trace = graph_objs.Scatter( - x=listx, - y=listy, - mode='markers', - marker=dict( - size=size, - color=index_vals, - colorscale=color, - showscale=True), - showlegend=False, - **kwargs - ) - # Generate trace with INVISIBLE icon - else: - if (listx == listy) and (diag == 'histogram'): - trace = graph_objs.Histogram( - x=listx, - marker=dict( - color=theme[0]), - showlegend=False - ) - elif (listx == listy) and (diag == 'box'): - trace = graph_objs.Box( - y=listx, - marker=dict( - color=theme[0]), - showlegend=False - ) - else: - if 'marker' in kwargs: - kwargs['marker']['size'] = size - kwargs['marker']['color'] = index_vals - kwargs['marker']['colorscale'] = color - kwargs['marker']['showscale'] = False - trace = graph_objs.Scatter( - x=listx, - y=listy, - mode='markers', - showlegend=False, - **kwargs - ) - else: - trace = graph_objs.Scatter( - x=listx, - y=listy, - mode='markers', - marker=dict( - size=size, - color=index_vals, - colorscale=color, - showscale=False), - showlegend=False, - **kwargs - ) - # Push the trace into list - trace_list.append(trace) - legend_param += 1 - - trace_index = 0 - indices = range(1, dim + 1) - for y_index in indices: - for x_index in indices: - fig.append_trace(trace_list[trace_index], - y_index, - x_index) - trace_index += 1 - - # Insert headers into the figure - for j in range(dim): - xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) - fig['layout'][xaxis_key].update(title=headers[j]) - for j in range(dim): - yaxis_key = 'yaxis{}'.format(1 + (dim * j)) - fig['layout'][yaxis_key].update(title=headers[j]) - - FigureFactory._hide_tick_labels_from_box_subplots(fig) - - if diag == 'histogram': - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True, - barmode='stack') - return fig - - elif diag == 'box': - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True) - return fig - - else: - fig['layout'].update( - height=height, width=width, - title=title, - showlegend=True) - return fig - - @staticmethod - def _hide_tick_labels_from_box_subplots(fig): - """ - Hides tick labels for box plots in scatterplotmatrix subplots. - """ - boxplot_xaxes = [] - for trace in fig['data']: - if trace['type'] == 'box': - # stores the xaxes which correspond to boxplot subplots - # since we use xaxis1, xaxis2, etc, in plotly.py - boxplot_xaxes.append( - 'xaxis{}'.format(trace['xaxis'][1:]) - ) - for xaxis in boxplot_xaxes: - fig['layout'][xaxis]['showticklabels'] = False - - @staticmethod - def _validate_index(index_vals): - """ - Validates if a list contains all numbers or all strings - - :raises: (PlotlyError) If there are any two items in the list whose - types differ - """ - from numbers import Number - if isinstance(index_vals[0], Number): - if not all(isinstance(item, Number) for item in index_vals): - raise exceptions.PlotlyError("Error in indexing column. " - "Make sure all entries of each " - "column are all numbers or " - "all strings.") - - elif isinstance(index_vals[0], str): - if not all(isinstance(item, str) for item in index_vals): - raise exceptions.PlotlyError("Error in indexing column. " - "Make sure all entries of each " - "column are all numbers or " - "all strings.") - - @staticmethod - def _validate_dataframe(array): - """ - Validates all strings or numbers in each dataframe column - - :raises: (PlotlyError) If there are any two items in any list whose - types differ - """ - from numbers import Number - for vector in array: - if isinstance(vector[0], Number): - if not all(isinstance(item, Number) for item in vector): - raise exceptions.PlotlyError("Error in dataframe. " - "Make sure all entries of " - "each column are either " - "numbers or strings.") - elif isinstance(vector[0], str): - if not all(isinstance(item, str) for item in vector): - raise exceptions.PlotlyError("Error in dataframe. " - "Make sure all entries of " - "each column are either " - "numbers or strings.") - - @staticmethod - def _validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs): - """ - Validates basic inputs for FigureFactory.create_scatterplotmatrix() - - :raises: (PlotlyError) If pandas is not imported - :raises: (PlotlyError) If pandas dataframe is not inputted - :raises: (PlotlyError) If pandas dataframe has <= 1 columns - :raises: (PlotlyError) If diagonal plot choice (diag) is not one of - the viable options - :raises: (PlotlyError) If colormap_type is not a valid choice - :raises: (PlotlyError) If kwargs contains 'size', 'color' or - 'colorscale' - """ - if _pandas_imported is False: - raise ImportError("FigureFactory.scatterplotmatrix requires " - "a pandas DataFrame.") - - # Check if pandas dataframe - if not isinstance(df, pd.core.frame.DataFrame): - raise exceptions.PlotlyError("Dataframe not inputed. Please " - "use a pandas dataframe to pro" - "duce a scatterplot matrix.") - - # Check if dataframe is 1 column or less - if len(df.columns) <= 1: - raise exceptions.PlotlyError("Dataframe has only one column. To " - "use the scatterplot matrix, use at " - "least 2 columns.") - - # Check that diag parameter is a valid selection - if diag not in DIAG_CHOICES: - raise exceptions.PlotlyError("Make sure diag is set to " - "one of {}".format(DIAG_CHOICES)) - - # Check that colormap_types is a valid selection - if colormap_type not in VALID_COLORMAP_TYPES: - raise exceptions.PlotlyError("Must choose a valid colormap type. " - "Either 'cat' or 'seq' for a cate" - "gorical and sequential colormap " - "respectively.") - - # Check for not 'size' or 'color' in 'marker' of **kwargs - if 'marker' in kwargs: - FORBIDDEN_PARAMS = ['size', 'color', 'colorscale'] - if any(param in kwargs['marker'] for param in FORBIDDEN_PARAMS): - raise exceptions.PlotlyError("Your kwargs dictionary cannot " - "include the 'size', 'color' or " - "'colorscale' key words inside " - "the marker dict since 'size' is " - "already an argument of the " - "scatterplot matrix function and " - "both 'color' and 'colorscale " - "are set internally.") - - @staticmethod - def _endpts_to_intervals(endpts): - """ - Returns a list of intervals for categorical colormaps - - Accepts a list or tuple of sequentially increasing numbers and returns - a list representation of the mathematical intervals with these numbers - as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]] - - :raises: (PlotlyError) If input is not a list or tuple - :raises: (PlotlyError) If the input contains a string - :raises: (PlotlyError) If any number does not increase after the - previous one in the sequence - """ - length = len(endpts) - # Check if endpts is a list or tuple - if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))): - raise exceptions.PlotlyError("The intervals_endpts argument must " - "be a list or tuple of a sequence " - "of increasing numbers.") - # Check if endpts contains only numbers - for item in endpts: - if isinstance(item, str): - raise exceptions.PlotlyError("The intervals_endpts argument " - "must be a list or tuple of a " - "sequence of increasing " - "numbers.") - # Check if numbers in endpts are increasing - for k in range(length-1): - if endpts[k] >= endpts[k+1]: - raise exceptions.PlotlyError("The intervals_endpts argument " - "must be a list or tuple of a " - "sequence of increasing " - "numbers.") - else: - intervals = [] - # add -inf to intervals - intervals.append([float('-inf'), endpts[0]]) - for k in range(length - 1): - interval = [] - interval.append(endpts[k]) - interval.append(endpts[k + 1]) - intervals.append(interval) - # add +inf to intervals - intervals.append([endpts[length - 1], float('inf')]) - return intervals - - @staticmethod - def _convert_to_RGB_255(colors): - """ - Multiplies each element of a triplet by 255 - - Each coordinate of the color tuple is rounded to the nearest float and - then is turned into an integer. If a number is of the form x.5, then - if x is odd, the number rounds up to (x+1). Otherwise, it rounds down - to just x. This is the way rounding works in Python 3 and in current - statistical analysis to avoid rounding bias - """ - rgb_components = [] - - for component in colors: - rounded_num = decimal.Decimal(str(component*255.0)).quantize( - decimal.Decimal('1'), rounding=decimal.ROUND_HALF_EVEN - ) - # convert rounded number to an integer from 'Decimal' form - rounded_num = int(rounded_num) - rgb_components.append(rounded_num) - - return (rgb_components[0], rgb_components[1], rgb_components[2]) - - @staticmethod - def _n_colors(lowcolor, highcolor, n_colors): - """ - Splits a low and high color into a list of n_colors colors in it - - Accepts two color tuples and returns a list of n_colors colors - which form the intermediate colors between lowcolor and highcolor - from linearly interpolating through RGB space - - """ - diff_0 = float(highcolor[0] - lowcolor[0]) - incr_0 = diff_0/(n_colors - 1) - diff_1 = float(highcolor[1] - lowcolor[1]) - incr_1 = diff_1/(n_colors - 1) - diff_2 = float(highcolor[2] - lowcolor[2]) - incr_2 = diff_2/(n_colors - 1) - color_tuples = [] - - for index in range(n_colors): - new_tuple = (lowcolor[0] + (index * incr_0), - lowcolor[1] + (index * incr_1), - lowcolor[2] + (index * incr_2)) - color_tuples.append(new_tuple) - - return color_tuples - - @staticmethod - def _label_rgb(colors): - """ - Takes tuple (a, b, c) and returns an rgb color 'rgb(a, b, c)' - """ - return ('rgb(%s, %s, %s)' % (colors[0], colors[1], colors[2])) - - @staticmethod - def _unlabel_rgb(colors): - """ - Takes rgb color(s) 'rgb(a, b, c)' and returns tuple(s) (a, b, c) - - This function takes either an 'rgb(a, b, c)' color or a list of - such colors and returns the color tuples in tuple(s) (a, b, c) - - """ - str_vals = '' - for index in range(len(colors)): - try: - float(colors[index]) - str_vals = str_vals + colors[index] - except ValueError: - if colors[index] == ',' or colors[index] == '.': - str_vals = str_vals + colors[index] - - str_vals = str_vals + ',' - numbers = [] - str_num = '' - for char in str_vals: - if char != ',': - str_num = str_num + char - else: - numbers.append(float(str_num)) - str_num = '' - return (numbers[0], numbers[1], numbers[2]) - - @staticmethod - def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter', - height=500, width=500, size=6, - title='Scatterplot Matrix', colormap=None, - colormap_type='cat', dataframe=None, - headers=None, index_vals=None, **kwargs): - """ - Returns data for a scatterplot matrix. - - :param (array) df: array of the data with column headers - :param (str) index: name of the index column in data array - :param (list|tuple) endpts: takes an increasing sequece of numbers - that defines intervals on the real line. They are used to group - the entries in an index of numbers into their corresponding - interval and therefore can be treated as categorical data - :param (str) diag: sets the chart type for the main diagonal plots. - The options are 'scatter', 'histogram' and 'box'. - :param (int|float) height: sets the height of the chart - :param (int|float) width: sets the width of the chart - :param (float) size: sets the marker size (in px) - :param (str) title: the title label of the scatterplot matrix - :param (str|tuple|list|dict) colormap: either a plotly scale name, - an rgb or hex color, a color tuple, a list of colors or a - dictionary. An rgb color is of the form 'rgb(x, y, z)' where - x, y and z belong to the interval [0, 255] and a color tuple is a - tuple of the form (a, b, c) where a, b and c belong to [0, 1]. - If colormap is a list, it must contain valid color types as its - members. - If colormap is a dictionary, all the string entries in - the index column must be a key in colormap. In this case, the - colormap_type is forced to 'cat' or categorical - :param (str) colormap_type: determines how colormap is interpreted. - Valid choices are 'seq' (sequential) and 'cat' (categorical). If - 'seq' is selected, only the first two colors in colormap will be - considered (when colormap is a list) and the index values will be - linearly interpolated between those two colors. This option is - forced if all index values are numeric. - If 'cat' is selected, a color from colormap will be assigned to - each category from index, including the intervals if endpts is - being used - :param (dict) **kwargs: a dictionary of scatterplot arguments - The only forbidden parameters are 'size', 'color' and - 'colorscale' in 'marker' - - Example 1: Vanilla Scatterplot Matrix - ``` - import plotly.plotly as py - from plotly.graph_objs import graph_objs - from plotly.tools import FigureFactory as FF - - import numpy as np - import pandas as pd - - # Create dataframe - df = pd.DataFrame(np.random.randn(10, 2), - columns=['Column 1', 'Column 2']) - - # Create scatterplot matrix - fig = FF.create_scatterplotmatrix(df) - - # Plot - py.iplot(fig, filename='Vanilla Scatterplot Matrix') - ``` - - Example 2: Indexing a Column - ``` - import plotly.plotly as py - from plotly.graph_objs import graph_objs - from plotly.tools import FigureFactory as FF - - import numpy as np - import pandas as pd - - # Create dataframe with index - df = pd.DataFrame(np.random.randn(10, 2), - columns=['A', 'B']) - - # Add another column of strings to the dataframe - df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', - 'grape', 'pear', 'pear', 'apple', 'pear']) - - # Create scatterplot matrix - fig = FF.create_scatterplotmatrix(df, index='Fruit', size=10) - - # Plot - py.iplot(fig, filename = 'Scatterplot Matrix with Index') - ``` - - Example 3: Styling the Diagonal Subplots - ``` - import plotly.plotly as py - from plotly.graph_objs import graph_objs - from plotly.tools import FigureFactory as FF - - import numpy as np - import pandas as pd - - # Create dataframe with index - df = pd.DataFrame(np.random.randn(10, 4), - columns=['A', 'B', 'C', 'D']) - - # Add another column of strings to the dataframe - df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', - 'grape', 'pear', 'pear', 'apple', 'pear']) - - # Create scatterplot matrix - fig = FF.create_scatterplotmatrix(df, diag='box', index='Fruit', - height=1000, width=1000) - - # Plot - py.iplot(fig, filename = 'Scatterplot Matrix - Diagonal Styling') - ``` - - Example 4: Use a Theme to Style the Subplots - ``` - import plotly.plotly as py - from plotly.graph_objs import graph_objs - from plotly.tools import FigureFactory as FF - - import numpy as np - import pandas as pd - - # Create dataframe with random data - df = pd.DataFrame(np.random.randn(100, 3), - columns=['A', 'B', 'C']) - - # Create scatterplot matrix using a built-in - # Plotly palette scale and indexing column 'A' - fig = FF.create_scatterplotmatrix(df, diag='histogram', - index='A', colormap='Blues', - height=800, width=800) - - # Plot - py.iplot(fig, filename = 'Scatterplot Matrix - Colormap Theme') - ``` - - Example 5: Example 4 with Interval Factoring - ``` - import plotly.plotly as py - from plotly.graph_objs import graph_objs - from plotly.tools import FigureFactory as FF - - import numpy as np - import pandas as pd - - # Create dataframe with random data - df = pd.DataFrame(np.random.randn(100, 3), - columns=['A', 'B', 'C']) - - # Create scatterplot matrix using a list of 2 rgb tuples - # and endpoints at -1, 0 and 1 - fig = FF.create_scatterplotmatrix(df, diag='histogram', index='A', - colormap=['rgb(140, 255, 50)', - 'rgb(170, 60, 115)', - '#6c4774', - (0.5, 0.1, 0.8)], - endpts=[-1, 0, 1], - height=800, width=800) - - # Plot - py.iplot(fig, filename = 'Scatterplot Matrix - Intervals') - ``` - - Example 6: Using the colormap as a Dictionary - ``` - import plotly.plotly as py - from plotly.graph_objs import graph_objs - from plotly.tools import FigureFactory as FF - - import numpy as np - import pandas as pd - import random - - # Create dataframe with random data - df = pd.DataFrame(np.random.randn(100, 3), - columns=['Column A', - 'Column B', - 'Column C']) - - # Add new color column to dataframe - new_column = [] - strange_colors = ['turquoise', 'limegreen', 'goldenrod'] - - for j in range(100): - new_column.append(random.choice(strange_colors)) - df['Colors'] = pd.Series(new_column, index=df.index) - - # Create scatterplot matrix using a dictionary of hex color values - # which correspond to actual color names in 'Colors' column - fig = FF.create_scatterplotmatrix( - df, diag='box', index='Colors', - colormap= dict( - turquoise = '#00F5FF', - limegreen = '#32CD32', - goldenrod = '#DAA520' - ), - colormap_type='cat', - height=800, width=800 - ) - - # Plot - py.iplot(fig, filename = 'Scatterplot Matrix - colormap dictionary ') - ``` - """ - # TODO: protected until #282 - if dataframe is None: - dataframe = [] - if headers is None: - headers = [] - if index_vals is None: - index_vals = [] - - FigureFactory._validate_scatterplotmatrix(df, index, diag, - colormap_type, **kwargs) - - # Validate colormap - if isinstance(colormap, dict): - colormap = FigureFactory._validate_colors_dict(colormap, 'rgb') - else: - colormap = FigureFactory._validate_colors(colormap, 'rgb') - - if not index: - for name in df: - headers.append(name) - for name in headers: - dataframe.append(df[name].values.tolist()) - # Check for same data-type in df columns - FigureFactory._validate_dataframe(dataframe) - figure = FigureFactory._scatterplot(dataframe, headers, diag, - size, height, width, title, - **kwargs) - return figure - else: - # Validate index selection - if index not in df: - raise exceptions.PlotlyError("Make sure you set the index " - "input variable to one of the " - "column names of your " - "dataframe.") - index_vals = df[index].values.tolist() - for name in df: - if name != index: - headers.append(name) - for name in headers: - dataframe.append(df[name].values.tolist()) - - # check for same data-type in each df column - FigureFactory._validate_dataframe(dataframe) - FigureFactory._validate_index(index_vals) - - # check if all colormap keys are in the index - # if colormap is a dictionary - if isinstance(colormap, dict): - for key in colormap: - if not all(index in colormap for index in index_vals): - raise exceptions.PlotlyError("If colormap is a " - "dictionary, all the " - "names in the index " - "must be keys.") - figure = FigureFactory._scatterplot_dict( - dataframe, headers, diag, size, height, width, title, - index, index_vals, endpts, colormap, colormap_type, - **kwargs - ) - return figure - - else: - figure = FigureFactory._scatterplot_theme( - dataframe, headers, diag, size, height, width, title, - index, index_vals, endpts, colormap, colormap_type, - **kwargs - ) - return figure - - @staticmethod - def _validate_equal_length(*args): - """ - Validates that data lists or ndarrays are the same length. - - :raises: (PlotlyError) If any data lists are not the same length. - """ - length = len(args[0]) - if any(len(lst) != length for lst in args): - raise exceptions.PlotlyError("Oops! Your data lists or ndarrays " - "should be the same length.") - - @staticmethod - def _validate_ohlc(open, high, low, close, direction, **kwargs): - """ - ohlc and candlestick specific validations - - Specifically, this checks that the high value is the greatest value and - the low value is the lowest value in each unit. - - See FigureFactory.create_ohlc() or FigureFactory.create_candlestick() - for params - - :raises: (PlotlyError) If the high value is not the greatest value in - each unit. - :raises: (PlotlyError) If the low value is not the lowest value in each - unit. - :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing' - """ - for lst in [open, low, close]: - for index in range(len(high)): - if high[index] < lst[index]: - raise exceptions.PlotlyError("Oops! Looks like some of " - "your high values are less " - "the corresponding open, " - "low, or close values. " - "Double check that your data " - "is entered in O-H-L-C order") - - for lst in [open, high, close]: - for index in range(len(low)): - if low[index] > lst[index]: - raise exceptions.PlotlyError("Oops! Looks like some of " - "your low values are greater " - "than the corresponding high" - ", open, or close values. " - "Double check that your data " - "is entered in O-H-L-C order") - - direction_opts = ('increasing', 'decreasing', 'both') - if direction not in direction_opts: - raise exceptions.PlotlyError("direction must be defined as " - "'increasing', 'decreasing', or " - "'both'") - - @staticmethod - def _validate_distplot(hist_data, curve_type): - """ - Distplot-specific validations - - :raises: (PlotlyError) If hist_data is not a list of lists - :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or - 'normal'). - """ - try: - import pandas as pd - _pandas_imported = True - except ImportError: - _pandas_imported = False - - hist_data_types = (list,) - if _numpy_imported: - hist_data_types += (np.ndarray,) - if _pandas_imported: - hist_data_types += (pd.core.series.Series,) - - if not isinstance(hist_data[0], hist_data_types): - raise exceptions.PlotlyError("Oops, this function was written " - "to handle multiple datasets, if " - "you want to plot just one, make " - "sure your hist_data variable is " - "still a list of lists, i.e. x = " - "[1, 2, 3] -> x = [[1, 2, 3]]") - - curve_opts = ('kde', 'normal') - if curve_type not in curve_opts: - raise exceptions.PlotlyError("curve_type must be defined as " - "'kde' or 'normal'") - - if _scipy_imported is False: - raise ImportError("FigureFactory.create_distplot requires scipy") - - @staticmethod - def _validate_positive_scalars(**kwargs): - """ - Validates that all values given in key/val pairs are positive. - - Accepts kwargs to improve Exception messages. - - :raises: (PlotlyError) If any value is < 0 or raises. - """ - for key, val in kwargs.items(): - try: - if val <= 0: - raise ValueError('{} must be > 0, got {}'.format(key, val)) - except TypeError: - raise exceptions.PlotlyError('{} must be a number, got {}' - .format(key, val)) - - @staticmethod - def _validate_streamline(x, y): - """ - Streamline-specific validations - - Specifically, this checks that x and y are both evenly spaced, - and that the package numpy is available. - - See FigureFactory.create_streamline() for params - - :raises: (ImportError) If numpy is not available. - :raises: (PlotlyError) If x is not evenly spaced. - :raises: (PlotlyError) If y is not evenly spaced. - """ - if _numpy_imported is False: - raise ImportError("FigureFactory.create_streamline requires numpy") - for index in range(len(x) - 1): - if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001: - raise exceptions.PlotlyError("x must be a 1 dimensional, " - "evenly spaced array") - for index in range(len(y) - 1): - if ((y[index + 1] - y[index]) - - (y[1] - y[0])) > .0001: - raise exceptions.PlotlyError("y must be a 1 dimensional, " - "evenly spaced array") - - @staticmethod - def _validate_annotated_heatmap(z, x, y, annotation_text): - """ - Annotated-heatmap-specific validations - - Check that if a text matrix is supplied, it has the same - dimensions as the z matrix. - - See FigureFactory.create_annotated_heatmap() for params - - :raises: (PlotlyError) If z and text matrices do not have the same - dimensions. - """ - if annotation_text is not None and isinstance(annotation_text, list): - FigureFactory._validate_equal_length(z, annotation_text) - for lst in range(len(z)): - if len(z[lst]) != len(annotation_text[lst]): - raise exceptions.PlotlyError("z and text should have the " - "same dimensions") - - if x: - if len(x) != len(z[0]): - raise exceptions.PlotlyError("oops, the x list that you " - "provided does not match the " - "width of your z matrix ") - - if y: - if len(y) != len(z): - raise exceptions.PlotlyError("oops, the y list that you " - "provided does not match the " - "length of your z matrix ") - - @staticmethod - def _validate_table(table_text, font_colors): - """ - Table-specific validations - - Check that font_colors is supplied correctly (1, 3, or len(text) - colors). - - :raises: (PlotlyError) If font_colors is supplied incorretly. - - See FigureFactory.create_table() for params - """ - font_colors_len_options = [1, 3, len(table_text)] - if len(font_colors) not in font_colors_len_options: - raise exceptions.PlotlyError("Oops, font_colors should be a list " - "of length 1, 3 or len(text)") - - @staticmethod - def _flatten(array): - """ - Uses list comprehension to flatten array - - :param (array): An iterable to flatten - :raises (PlotlyError): If iterable is not nested. - :rtype (list): The flattened list. - """ - try: - return [item for sublist in array for item in sublist] - except TypeError: - raise exceptions.PlotlyError("Your data array could not be " - "flattened! Make sure your data is " - "entered as lists or ndarrays!") - - @staticmethod - def _hex_to_rgb(value): - """ - Calculates rgb values from a hex color code. - - :param (string) value: Hex color string - - :rtype (tuple) (r_value, g_value, b_value): tuple of rgb values - """ - value = value.lstrip('#') - hex_total_length = len(value) - rgb_section_length = hex_total_length // 3 - return tuple(int(value[i:i + rgb_section_length], 16) - for i in range(0, hex_total_length, rgb_section_length)) - - @staticmethod - def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3, - angle=math.pi / 9, **kwargs): - """ - Returns data for a quiver plot. - - :param (list|ndarray) x: x coordinates of the arrow locations - :param (list|ndarray) y: y coordinates of the arrow locations - :param (list|ndarray) u: x components of the arrow vectors - :param (list|ndarray) v: y components of the arrow vectors - :param (float in [0,1]) scale: scales size of the arrows(ideally to - avoid overlap). Default = .1 - :param (float in [0,1]) arrow_scale: value multiplied to length of barb - to get length of arrowhead. Default = .3 - :param (angle in radians) angle: angle of arrowhead. Default = pi/9 - :param kwargs: kwargs passed through plotly.graph_objs.Scatter - for more information on valid kwargs call - help(plotly.graph_objs.Scatter) - - :rtype (dict): returns a representation of quiver figure. - - Example 1: Trivial Quiver - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import math - - # 1 Arrow from (0,0) to (1,1) - fig = FF.create_quiver(x=[0], y=[0], - u=[1], v=[1], - scale=1) - - py.plot(fig, filename='quiver') - ``` - - Example 2: Quiver plot using meshgrid - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import numpy as np - import math - - # Add data - x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2)) - u = np.cos(x)*y - v = np.sin(x)*y - - #Create quiver - fig = FF.create_quiver(x, y, u, v) - - # Plot - py.plot(fig, filename='quiver') - ``` - - Example 3: Styling the quiver plot - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - import numpy as np - import math - - # Add data - x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5), - np.arange(-math.pi, math.pi, .5)) - u = np.cos(x)*y - v = np.sin(x)*y - - # Create quiver - fig = FF.create_quiver(x, y, u, v, scale=.2, - arrow_scale=.3, - angle=math.pi/6, - name='Wind Velocity', - line=Line(width=1)) - - # Add title to layout - fig['layout'].update(title='Quiver Plot') - - # Plot - py.plot(fig, filename='quiver') - ``` - """ - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - FigureFactory._validate_equal_length(x, y, u, v) - FigureFactory._validate_positive_scalars(arrow_scale=arrow_scale, - scale=scale) - - barb_x, barb_y = _Quiver(x, y, u, v, scale, - arrow_scale, angle).get_barbs() - arrow_x, arrow_y = _Quiver(x, y, u, v, scale, - arrow_scale, angle).get_quiver_arrows() - quiver = graph_objs.Scatter(x=barb_x + arrow_x, - y=barb_y + arrow_y, - mode='lines', **kwargs) - - data = [quiver] - layout = graph_objs.Layout(hovermode='closest') - - return graph_objs.Figure(data=data, layout=layout) - - @staticmethod - def create_streamline(x, y, u, v, - density=1, angle=math.pi / 9, - arrow_scale=.09, **kwargs): - """ - Returns data for a streamline plot. - - :param (list|ndarray) x: 1 dimensional, evenly spaced list or array - :param (list|ndarray) y: 1 dimensional, evenly spaced list or array - :param (ndarray) u: 2 dimensional array - :param (ndarray) v: 2 dimensional array - :param (float|int) density: controls the density of streamlines in - plot. This is multiplied by 30 to scale similiarly to other - available streamline functions such as matplotlib. - Default = 1 - :param (angle in radians) angle: angle of arrowhead. Default = pi/9 - :param (float in [0,1]) arrow_scale: value to scale length of arrowhead - Default = .09 - :param kwargs: kwargs passed through plotly.graph_objs.Scatter - for more information on valid kwargs call - help(plotly.graph_objs.Scatter) - - :rtype (dict): returns a representation of streamline figure. - - Example 1: Plot simple streamline and increase arrow size - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import numpy as np - import math - - # Add data - x = np.linspace(-3, 3, 100) - y = np.linspace(-3, 3, 100) - Y, X = np.meshgrid(x, y) - u = -1 - X**2 + Y - v = 1 + X - Y**2 - u = u.T # Transpose - v = v.T # Transpose - - # Create streamline - fig = FF.create_streamline(x, y, u, v, - arrow_scale=.1) - - # Plot - py.plot(fig, filename='streamline') - ``` - - Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import numpy as np - import math - - # Add data - N = 50 - x_start, x_end = -2.0, 2.0 - y_start, y_end = -1.0, 1.0 - x = np.linspace(x_start, x_end, N) - y = np.linspace(y_start, y_end, N) - X, Y = np.meshgrid(x, y) - ss = 5.0 - x_s, y_s = -1.0, 0.0 - - # Compute the velocity field on the mesh grid - u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2) - v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2) - - # Create streamline - fig = FF.create_streamline(x, y, u_s, v_s, - density=2, name='streamline') - - # Add source point - point = Scatter(x=[x_s], y=[y_s], mode='markers', - marker=Marker(size=14), name='source point') - - # Plot - fig['data'].append(point) - py.plot(fig, filename='streamline') - ``` - """ - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - FigureFactory._validate_equal_length(x, y) - FigureFactory._validate_equal_length(u, v) - FigureFactory._validate_streamline(x, y) - FigureFactory._validate_positive_scalars(density=density, - arrow_scale=arrow_scale) - - streamline_x, streamline_y = _Streamline(x, y, u, v, - density, angle, - arrow_scale).sum_streamlines() - arrow_x, arrow_y = _Streamline(x, y, u, v, - density, angle, - arrow_scale).get_streamline_arrows() - - streamline = graph_objs.Scatter(x=streamline_x + arrow_x, - y=streamline_y + arrow_y, - mode='lines', **kwargs) - - data = [streamline] - layout = graph_objs.Layout(hovermode='closest') - - return graph_objs.Figure(data=data, layout=layout) - - @staticmethod - def _make_increasing_ohlc(open, high, low, close, dates, **kwargs): - """ - Makes increasing ohlc sticks - - _make_increasing_ohlc() and _make_decreasing_ohlc separate the - increasing trace from the decreasing trace so kwargs (such as - color) can be passed separately to increasing or decreasing traces - when direction is set to 'increasing' or 'decreasing' in - FigureFactory.create_candlestick() - - :param (list) open: opening values - :param (list) high: high values - :param (list) low: low values - :param (list) close: closing values - :param (list) dates: list of datetime objects. Default: None - :param kwargs: kwargs to be passed to increasing trace via - plotly.graph_objs.Scatter. - - :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc - sticks. - """ - (flat_increase_x, - flat_increase_y, - text_increase) = _OHLC(open, high, low, close, dates).get_increase() - - if 'name' in kwargs: - showlegend = True - else: - kwargs.setdefault('name', 'Increasing') - showlegend = False - - kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR, - width=1)) - kwargs.setdefault('text', text_increase) - - ohlc_incr = dict(type='scatter', - x=flat_increase_x, - y=flat_increase_y, - mode='lines', - showlegend=showlegend, - **kwargs) - return ohlc_incr - - @staticmethod - def _make_decreasing_ohlc(open, high, low, close, dates, **kwargs): - """ - Makes decreasing ohlc sticks - - :param (list) open: opening values - :param (list) high: high values - :param (list) low: low values - :param (list) close: closing values - :param (list) dates: list of datetime objects. Default: None - :param kwargs: kwargs to be passed to increasing trace via - plotly.graph_objs.Scatter. - - :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc - sticks. - """ - (flat_decrease_x, - flat_decrease_y, - text_decrease) = _OHLC(open, high, low, close, dates).get_decrease() - - kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR, - width=1)) - kwargs.setdefault('text', text_decrease) - kwargs.setdefault('showlegend', False) - kwargs.setdefault('name', 'Decreasing') - - ohlc_decr = dict(type='scatter', - x=flat_decrease_x, - y=flat_decrease_y, - mode='lines', - **kwargs) - return ohlc_decr - - @staticmethod - def create_ohlc(open, high, low, close, - dates=None, direction='both', - **kwargs): - """ - BETA function that creates an ohlc chart - - :param (list) open: opening values - :param (list) high: high values - :param (list) low: low values - :param (list) close: closing - :param (list) dates: list of datetime objects. Default: None - :param (string) direction: direction can be 'increasing', 'decreasing', - or 'both'. When the direction is 'increasing', the returned figure - consists of all units where the close value is greater than the - corresponding open value, and when the direction is 'decreasing', - the returned figure consists of all units where the close value is - less than or equal to the corresponding open value. When the - direction is 'both', both increasing and decreasing units are - returned. Default: 'both' - :param kwargs: kwargs passed through plotly.graph_objs.Scatter. - These kwargs describe other attributes about the ohlc Scatter trace - such as the color or the legend name. For more information on valid - kwargs call help(plotly.graph_objs.Scatter) - - :rtype (dict): returns a representation of an ohlc chart figure. - - Example 1: Simple OHLC chart from a Pandas DataFrame - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from datetime import datetime - - import pandas.io.data as web - - df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), datetime(2008, 10, 15)) - fig = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index) - - py.plot(fig, filename='finance/aapl-ohlc') - ``` - - Example 2: Add text and annotations to the OHLC chart - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from datetime import datetime - - import pandas.io.data as web - - df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), datetime(2008, 10, 15)) - fig = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index) - - # Update the fig - all options here: https://plot.ly/python/reference/#Layout - fig['layout'].update({ - 'title': 'The Great Recession', - 'yaxis': {'title': 'AAPL Stock'}, - 'shapes': [{ - 'x0': '2008-09-15', 'x1': '2008-09-15', 'type': 'line', - 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper', - 'line': {'color': 'rgb(40,40,40)', 'width': 0.5} - }], - 'annotations': [{ - 'text': "the fall of Lehman Brothers", - 'x': '2008-09-15', 'y': 1.02, - 'xref': 'x', 'yref': 'paper', - 'showarrow': False, 'xanchor': 'left' - }] - }) - - py.plot(fig, filename='finance/aapl-recession-ohlc', validate=False) - ``` - - Example 3: Customize the OHLC colors - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import Line, Marker - from datetime import datetime - - import pandas.io.data as web - - df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1)) - - # Make increasing ohlc sticks and customize their color and name - fig_increasing = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index, - direction='increasing', name='AAPL', - line=Line(color='rgb(150, 200, 250)')) - - # Make decreasing ohlc sticks and customize their color and name - fig_decreasing = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index, - direction='decreasing', - line=Line(color='rgb(128, 128, 128)')) - - # Initialize the figure - fig = fig_increasing - - # Add decreasing data with .extend() - fig['data'].extend(fig_decreasing['data']) - - py.iplot(fig, filename='finance/aapl-ohlc-colors', validate=False) - ``` - - Example 4: OHLC chart with datetime objects - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - from datetime import datetime - - # Add data - open_data = [33.0, 33.3, 33.5, 33.0, 34.1] - high_data = [33.1, 33.3, 33.6, 33.2, 34.8] - low_data = [32.7, 32.7, 32.8, 32.6, 32.8] - close_data = [33.0, 32.9, 33.3, 33.1, 33.1] - dates = [datetime(year=2013, month=10, day=10), - datetime(year=2013, month=11, day=10), - datetime(year=2013, month=12, day=10), - datetime(year=2014, month=1, day=10), - datetime(year=2014, month=2, day=10)] - - # Create ohlc - fig = FF.create_ohlc(open_data, high_data, - low_data, close_data, dates=dates) - - py.iplot(fig, filename='finance/simple-ohlc', validate=False) - ``` - """ - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - if dates is not None: - FigureFactory._validate_equal_length(open, high, low, close, dates) - else: - FigureFactory._validate_equal_length(open, high, low, close) - FigureFactory._validate_ohlc(open, high, low, close, direction, - **kwargs) - - if direction is 'increasing': - ohlc_incr = FigureFactory._make_increasing_ohlc(open, high, - low, close, - dates, **kwargs) - data = [ohlc_incr] - elif direction is 'decreasing': - ohlc_decr = FigureFactory._make_decreasing_ohlc(open, high, - low, close, - dates, **kwargs) - data = [ohlc_decr] - else: - ohlc_incr = FigureFactory._make_increasing_ohlc(open, high, - low, close, - dates, **kwargs) - ohlc_decr = FigureFactory._make_decreasing_ohlc(open, high, - low, close, - dates, **kwargs) - data = [ohlc_incr, ohlc_decr] - - layout = graph_objs.Layout(xaxis=dict(zeroline=False), - hovermode='closest') - - return graph_objs.Figure(data=data, layout=layout) - - @staticmethod - def _make_increasing_candle(open, high, low, close, dates, **kwargs): - """ - Makes boxplot trace for increasing candlesticks - - _make_increasing_candle() and _make_decreasing_candle separate the - increasing traces from the decreasing traces so kwargs (such as - color) can be passed separately to increasing or decreasing traces - when direction is set to 'increasing' or 'decreasing' in - FigureFactory.create_candlestick() - - :param (list) open: opening values - :param (list) high: high values - :param (list) low: low values - :param (list) close: closing values - :param (list) dates: list of datetime objects. Default: None - :param kwargs: kwargs to be passed to increasing trace via - plotly.graph_objs.Scatter. - - :rtype (list) candle_incr_data: list of the box trace for - increasing candlesticks. - """ - increase_x, increase_y = _Candlestick( - open, high, low, close, dates, **kwargs).get_candle_increase() - - if 'line' in kwargs: - kwargs.setdefault('fillcolor', kwargs['line']['color']) - else: - kwargs.setdefault('fillcolor', _DEFAULT_INCREASING_COLOR) - if 'name' in kwargs: - kwargs.setdefault('showlegend', True) - else: - kwargs.setdefault('showlegend', False) - kwargs.setdefault('name', 'Increasing') - kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR)) - - candle_incr_data = dict(type='box', - x=increase_x, - y=increase_y, - whiskerwidth=0, - boxpoints=False, - **kwargs) - - return [candle_incr_data] - - @staticmethod - def _make_decreasing_candle(open, high, low, close, dates, **kwargs): - """ - Makes boxplot trace for decreasing candlesticks - - :param (list) open: opening values - :param (list) high: high values - :param (list) low: low values - :param (list) close: closing values - :param (list) dates: list of datetime objects. Default: None - :param kwargs: kwargs to be passed to decreasing trace via - plotly.graph_objs.Scatter. - - :rtype (list) candle_decr_data: list of the box trace for - decreasing candlesticks. - """ - - decrease_x, decrease_y = _Candlestick( - open, high, low, close, dates, **kwargs).get_candle_decrease() - - if 'line' in kwargs: - kwargs.setdefault('fillcolor', kwargs['line']['color']) - else: - kwargs.setdefault('fillcolor', _DEFAULT_DECREASING_COLOR) - kwargs.setdefault('showlegend', False) - kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR)) - kwargs.setdefault('name', 'Decreasing') - - candle_decr_data = dict(type='box', - x=decrease_x, - y=decrease_y, - whiskerwidth=0, - boxpoints=False, - **kwargs) - - return [candle_decr_data] - - @staticmethod - def create_candlestick(open, high, low, close, - dates=None, direction='both', **kwargs): - """ - BETA function that creates a candlestick chart - - :param (list) open: opening values - :param (list) high: high values - :param (list) low: low values - :param (list) close: closing values - :param (list) dates: list of datetime objects. Default: None - :param (string) direction: direction can be 'increasing', 'decreasing', - or 'both'. When the direction is 'increasing', the returned figure - consists of all candlesticks where the close value is greater than - the corresponding open value, and when the direction is - 'decreasing', the returned figure consists of all candlesticks - where the close value is less than or equal to the corresponding - open value. When the direction is 'both', both increasing and - decreasing candlesticks are returned. Default: 'both' - :param kwargs: kwargs passed through plotly.graph_objs.Scatter. - These kwargs describe other attributes about the ohlc Scatter trace - such as the color or the legend name. For more information on valid - kwargs call help(plotly.graph_objs.Scatter) - - :rtype (dict): returns a representation of candlestick chart figure. - - Example 1: Simple candlestick chart from a Pandas DataFrame - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from datetime import datetime - - import pandas.io.data as web - - df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1), datetime(2009, 4, 1)) - fig = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index) - py.plot(fig, filename='finance/aapl-candlestick', validate=False) - ``` - - Example 2: Add text and annotations to the candlestick chart - ``` - fig = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index) - # Update the fig - all options here: https://plot.ly/python/reference/#Layout - fig['layout'].update({ - 'title': 'The Great Recession', - 'yaxis': {'title': 'AAPL Stock'}, - 'shapes': [{ - 'x0': '2007-12-01', 'x1': '2007-12-01', - 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper', - 'line': {'color': 'rgb(30,30,30)', 'width': 1} - }], - 'annotations': [{ - 'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper', - 'showarrow': False, 'xanchor': 'left', - 'text': 'Official start of the recession' - }] - }) - py.plot(fig, filename='finance/aapl-recession-candlestick', validate=False) - ``` - - Example 3: Customize the candlestick colors - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - from plotly.graph_objs import Line, Marker - from datetime import datetime - - import pandas.io.data as web - - df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1)) - - # Make increasing candlesticks and customize their color and name - fig_increasing = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index, - direction='increasing', name='AAPL', - marker=Marker(color='rgb(150, 200, 250)'), - line=Line(color='rgb(150, 200, 250)')) - - # Make decreasing candlesticks and customize their color and name - fig_decreasing = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index, - direction='decreasing', - marker=Marker(color='rgb(128, 128, 128)'), - line=Line(color='rgb(128, 128, 128)')) - - # Initialize the figure - fig = fig_increasing - - # Add decreasing data with .extend() - fig['data'].extend(fig_decreasing['data']) - - py.iplot(fig, filename='finance/aapl-candlestick-custom', validate=False) - ``` - - Example 4: Candlestick chart with datetime objects - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - from datetime import datetime - - # Add data - open_data = [33.0, 33.3, 33.5, 33.0, 34.1] - high_data = [33.1, 33.3, 33.6, 33.2, 34.8] - low_data = [32.7, 32.7, 32.8, 32.6, 32.8] - close_data = [33.0, 32.9, 33.3, 33.1, 33.1] - dates = [datetime(year=2013, month=10, day=10), - datetime(year=2013, month=11, day=10), - datetime(year=2013, month=12, day=10), - datetime(year=2014, month=1, day=10), - datetime(year=2014, month=2, day=10)] - - # Create ohlc - fig = FF.create_candlestick(open_data, high_data, - low_data, close_data, dates=dates) - - py.iplot(fig, filename='finance/simple-candlestick', validate=False) - ``` - """ - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - if dates is not None: - FigureFactory._validate_equal_length(open, high, low, close, dates) - else: - FigureFactory._validate_equal_length(open, high, low, close) - FigureFactory._validate_ohlc(open, high, low, close, direction, - **kwargs) - - if direction is 'increasing': - candle_incr_data = FigureFactory._make_increasing_candle( - open, high, low, close, dates, **kwargs) - data = candle_incr_data - elif direction is 'decreasing': - candle_decr_data = FigureFactory._make_decreasing_candle( - open, high, low, close, dates, **kwargs) - data = candle_decr_data - else: - candle_incr_data = FigureFactory._make_increasing_candle( - open, high, low, close, dates, **kwargs) - candle_decr_data = FigureFactory._make_decreasing_candle( - open, high, low, close, dates, **kwargs) - data = candle_incr_data + candle_decr_data - - layout = graph_objs.Layout() - return graph_objs.Figure(data=data, layout=layout) - - @staticmethod - def create_distplot(hist_data, group_labels, - bin_size=1., curve_type='kde', - colors=[], rug_text=[], histnorm=DEFAULT_HISTNORM, - show_hist=True, show_curve=True, - show_rug=True): - """ - BETA function that creates a distplot similar to seaborn.distplot - - The distplot can be composed of all or any combination of the following - 3 components: (1) histogram, (2) curve: (a) kernel density estimation - or (b) normal curve, and (3) rug plot. Additionally, multiple distplots - (from multiple datasets) can be created in the same plot. - - :param (list[list]) hist_data: Use list of lists to plot multiple data - sets on the same plot. - :param (list[str]) group_labels: Names for each data set. - :param (list[float]|float) bin_size: Size of histogram bins. - Default = 1. - :param (str) curve_type: 'kde' or 'normal'. Default = 'kde' - :param (str) histnorm: 'probability density' or 'probability' - Default = 'probability density' - :param (bool) show_hist: Add histogram to distplot? Default = True - :param (bool) show_curve: Add curve to distplot? Default = True - :param (bool) show_rug: Add rug to distplot? Default = True - :param (list[str]) colors: Colors for traces. - :param (list[list]) rug_text: Hovertext values for rug_plot, - :return (dict): Representation of a distplot figure. - - Example 1: Simple distplot of 1 data set - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5, - 3.5, 4.1, 4.4, 4.5, 4.5, - 5.0, 5.0, 5.2, 5.5, 5.5, - 5.5, 5.5, 5.5, 6.1, 7.0]] - - group_labels = ['distplot example'] - - fig = FF.create_distplot(hist_data, group_labels) - - url = py.plot(fig, filename='Simple distplot', validate=False) - ``` - - Example 2: Two data sets and added rug text - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - # Add histogram data - hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6, - -0.9, -0.07, 1.95, 0.9, -0.2, - -0.5, 0.3, 0.4, -0.37, 0.6] - hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59, - 1.0, 0.8, 1.7, 0.5, 0.8, - -0.3, 1.2, 0.56, 0.3, 2.2] - - # Group data together - hist_data = [hist1_x, hist2_x] - - group_labels = ['2012', '2013'] - - # Add text - rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1', - 'f1', 'g1', 'h1', 'i1', 'j1', - 'k1', 'l1', 'm1', 'n1', 'o1'] - - rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2', - 'f2', 'g2', 'h2', 'i2', 'j2', - 'k2', 'l2', 'm2', 'n2', 'o2'] - - # Group text together - rug_text_all = [rug_text_1, rug_text_2] - - # Create distplot - fig = FF.create_distplot( - hist_data, group_labels, rug_text=rug_text_all, bin_size=.2) - - # Add title - fig['layout'].update(title='Dist Plot') - - # Plot! - url = py.plot(fig, filename='Distplot with rug text', validate=False) - ``` - - Example 3: Plot with normal curve and hide rug plot - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - import numpy as np - - x1 = np.random.randn(190) - x2 = np.random.randn(200)+1 - x3 = np.random.randn(200)-1 - x4 = np.random.randn(210)+2 - - hist_data = [x1, x2, x3, x4] - group_labels = ['2012', '2013', '2014', '2015'] - - fig = FF.create_distplot( - hist_data, group_labels, curve_type='normal', - show_rug=False, bin_size=.4) - - url = py.plot(fig, filename='hist and normal curve', validate=False) - - Example 4: Distplot with Pandas - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - import numpy as np - import pandas as pd - - df = pd.DataFrame({'2012': np.random.randn(200), - '2013': np.random.randn(200)+1}) - py.iplot(FF.create_distplot([df[c] for c in df.columns], df.columns), - filename='examples/distplot with pandas', - validate=False) - ``` - """ - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - FigureFactory._validate_distplot(hist_data, curve_type) - FigureFactory._validate_equal_length(hist_data, group_labels) - - if isinstance(bin_size, (float, int)): - bin_size = [bin_size]*len(hist_data) - - hist = _Distplot( - hist_data, histnorm, group_labels, bin_size, - curve_type, colors, rug_text, - show_hist, show_curve).make_hist() - - if curve_type == 'normal': - curve = _Distplot( - hist_data, histnorm, group_labels, bin_size, - curve_type, colors, rug_text, - show_hist, show_curve).make_normal() - else: - curve = _Distplot( - hist_data, histnorm, group_labels, bin_size, - curve_type, colors, rug_text, - show_hist, show_curve).make_kde() - - rug = _Distplot( - hist_data, histnorm, group_labels, bin_size, - curve_type, colors, rug_text, - show_hist, show_curve).make_rug() - - data = [] - if show_hist: - data.append(hist) - if show_curve: - data.append(curve) - if show_rug: - data.append(rug) - layout = graph_objs.Layout( - barmode='overlay', - hovermode='closest', - legend=dict(traceorder='reversed'), - xaxis1=dict(domain=[0.0, 1.0], - anchor='y2', - zeroline=False), - yaxis1=dict(domain=[0.35, 1], - anchor='free', - position=0.0), - yaxis2=dict(domain=[0, 0.25], - anchor='x1', - dtick=1, - showticklabels=False)) - else: - layout = graph_objs.Layout( - barmode='overlay', - hovermode='closest', - legend=dict(traceorder='reversed'), - xaxis1=dict(domain=[0.0, 1.0], - anchor='y2', - zeroline=False), - yaxis1=dict(domain=[0., 1], - anchor='free', - position=0.0)) - - data = sum(data, []) - return graph_objs.Figure(data=data, layout=layout) - - @staticmethod - def create_dendrogram(X, orientation="bottom", labels=None, - colorscale=None, distfun=None, - linkagefun=lambda x: sch.linkage(x, 'complete')): - """ - BETA function that returns a dendrogram Plotly figure object. - - :param (ndarray) X: Matrix of observations as array of arrays - :param (str) orientation: 'top', 'right', 'bottom', or 'left' - :param (list) labels: List of axis category labels(observation labels) - :param (list) colorscale: Optional colorscale for dendrogram tree - :param (function) distfun: Function to compute the pairwise distance from the observations - :param (function) linkagefun: Function to compute the linkage matrix from the pairwise distances - - clusters - - Example 1: Simple bottom oriented dendrogram - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import numpy as np - - X = np.random.rand(10,10) - dendro = FF.create_dendrogram(X) - plot_url = py.plot(dendro, filename='simple-dendrogram') - - ``` - - Example 2: Dendrogram to put on the left of the heatmap - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import numpy as np - - X = np.random.rand(5,5) - names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark'] - dendro = FF.create_dendrogram(X, orientation='right', labels=names) - dendro['layout'].update({'width':700, 'height':500}) - - py.iplot(dendro, filename='vertical-dendrogram') - ``` - - Example 3: Dendrogram with Pandas - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import numpy as np - import pandas as pd - - Index= ['A','B','C','D','E','F','G','H','I','J'] - df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index) - fig = FF.create_dendrogram(df, labels=Index) - url = py.plot(fig, filename='pandas-dendrogram') - ``` - """ - dependencies = (_scipy_imported and _scipy__spatial_imported and - _scipy__cluster__hierarchy_imported) - - if dependencies is False: - raise ImportError("FigureFactory.create_dendrogram requires scipy, \ - scipy.spatial and scipy.hierarchy") - - s = X.shape - if len(s) != 2: - exceptions.PlotlyError("X should be 2-dimensional array.") - - if distfun is None: - distfun = scs.distance.pdist - - dendrogram = _Dendrogram(X, orientation, labels, colorscale, - distfun=distfun, linkagefun=linkagefun) - - return {'layout': dendrogram.layout, - 'data': dendrogram.data} - - @staticmethod - def create_annotated_heatmap(z, x=None, y=None, annotation_text=None, - colorscale='RdBu', font_colors=None, - showscale=False, reversescale=False, - **kwargs): - """ - BETA function that creates annotated heatmaps - - This function adds annotations to each cell of the heatmap. - - :param (list[list]|ndarray) z: z matrix to create heatmap. - :param (list) x: x axis labels. - :param (list) y: y axis labels. - :param (list[list]|ndarray) annotation_text: Text strings for - annotations. Should have the same dimensions as the z matrix. If no - text is added, the values of the z matrix are annotated. Default = - z matrix values. - :param (list|str) colorscale: heatmap colorscale. - :param (list) font_colors: List of two color strings: [min_text_color, - max_text_color] where min_text_color is applied to annotations for - heatmap values < (max_value - min_value)/2. If font_colors is not - defined, the colors are defined logically as black or white - depending on the heatmap's colorscale. - :param (bool) showscale: Display colorscale. Default = False - :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. - These kwargs describe other attributes about the annotated Heatmap - trace such as the colorscale. For more information on valid kwargs - call help(plotly.graph_objs.Heatmap) - - Example 1: Simple annotated heatmap with default configuration - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - z = [[0.300000, 0.00000, 0.65, 0.300000], - [1, 0.100005, 0.45, 0.4300], - [0.300000, 0.00000, 0.65, 0.300000], - [1, 0.100005, 0.45, 0.00000]] - - figure = FF.create_annotated_heatmap(z) - py.iplot(figure) - ``` - """ - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - - # Avoiding mutables in the call signature - font_colors = font_colors if font_colors is not None else [] - FigureFactory._validate_annotated_heatmap(z, x, y, annotation_text) - annotations = _AnnotatedHeatmap(z, x, y, annotation_text, - colorscale, font_colors, reversescale, - **kwargs).make_annotations() - - if x or y: - trace = dict(type='heatmap', z=z, x=x, y=y, colorscale=colorscale, - showscale=showscale, **kwargs) - layout = dict(annotations=annotations, - xaxis=dict(ticks='', dtick=1, side='top', - gridcolor='rgb(0, 0, 0)'), - yaxis=dict(ticks='', dtick=1, ticksuffix=' ')) - else: - trace = dict(type='heatmap', z=z, colorscale=colorscale, - showscale=showscale, **kwargs) - layout = dict(annotations=annotations, - xaxis=dict(ticks='', side='top', - gridcolor='rgb(0, 0, 0)', - showticklabels=False), - yaxis=dict(ticks='', ticksuffix=' ', - showticklabels=False)) - - data = [trace] - - return graph_objs.Figure(data=data, layout=layout) - - @staticmethod - def create_table(table_text, colorscale=None, font_colors=None, - index=False, index_title='', annotation_offset=.45, - height_constant=30, hoverinfo='none', **kwargs): - """ - BETA function that creates data tables - - :param (pandas.Dataframe | list[list]) text: data for table. - :param (str|list[list]) colorscale: Colorscale for table where the - color at value 0 is the header color, .5 is the first table color - and 1 is the second table color. (Set .5 and 1 to avoid the striped - table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'], - [1, '#ffffff']] - :param (list) font_colors: Color for fonts in table. Can be a single - color, three colors, or a color for each row in the table. - Default=['#000000'] (black text for the entire table) - :param (int) height_constant: Constant multiplied by # of rows to - create table height. Default=30. - :param (bool) index: Create (header-colored) index column index from - Pandas dataframe or list[0] for each list in text. Default=False. - :param (string) index_title: Title for index column. Default=''. - :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. - These kwargs describe other attributes about the annotated Heatmap - trace such as the colorscale. For more information on valid kwargs - call help(plotly.graph_objs.Heatmap) - - Example 1: Simple Plotly Table - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - text = [['Country', 'Year', 'Population'], - ['US', 2000, 282200000], - ['Canada', 2000, 27790000], - ['US', 2010, 309000000], - ['Canada', 2010, 34000000]] - - table = FF.create_table(text) - py.iplot(table) - ``` - - Example 2: Table with Custom Coloring - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - text = [['Country', 'Year', 'Population'], - ['US', 2000, 282200000], - ['Canada', 2000, 27790000], - ['US', 2010, 309000000], - ['Canada', 2010, 34000000]] - - table = FF.create_table(text, - colorscale=[[0, '#000000'], - [.5, '#80beff'], - [1, '#cce5ff']], - font_colors=['#ffffff', '#000000', - '#000000']) - py.iplot(table) - ``` - Example 3: Simple Plotly Table with Pandas - ``` - import plotly.plotly as py - from plotly.tools import FigureFactory as FF - - import pandas as pd - - df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t') - df_p = df[0:25] - - table_simple = FF.create_table(df_p) - py.iplot(table_simple) - ``` - """ - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - - # Avoiding mutables in the call signature - colorscale = \ - colorscale if colorscale is not None else [[0, '#00083e'], - [.5, '#ededee'], - [1, '#ffffff']] - font_colors = font_colors if font_colors is not None else ['#ffffff', - '#000000', - '#000000'] - - FigureFactory._validate_table(table_text, font_colors) - table_matrix = _Table(table_text, colorscale, font_colors, index, - index_title, annotation_offset, - **kwargs).get_table_matrix() - annotations = _Table(table_text, colorscale, font_colors, index, - index_title, annotation_offset, - **kwargs).make_table_annotations() - - trace = dict(type='heatmap', z=table_matrix, opacity=.75, - colorscale=colorscale, showscale=False, - hoverinfo=hoverinfo, **kwargs) - - data = [trace] - layout = dict(annotations=annotations, - height=len(table_matrix)*height_constant + 50, - margin=dict(t=0, b=0, r=0, l=0), - yaxis=dict(autorange='reversed', zeroline=False, - gridwidth=2, ticks='', dtick=1, tick0=.5, - showticklabels=False), - xaxis=dict(zeroline=False, gridwidth=2, ticks='', - dtick=1, tick0=-0.5, showticklabels=False)) - return graph_objs.Figure(data=data, layout=layout) - - -class _Quiver(FigureFactory): - """ - Refer to FigureFactory.create_quiver() for docstring - """ - def __init__(self, x, y, u, v, - scale, arrow_scale, angle, **kwargs): - try: - x = FigureFactory._flatten(x) - except exceptions.PlotlyError: - pass - - try: - y = FigureFactory._flatten(y) - except exceptions.PlotlyError: - pass - - try: - u = FigureFactory._flatten(u) - except exceptions.PlotlyError: - pass - - try: - v = FigureFactory._flatten(v) - except exceptions.PlotlyError: - pass - - self.x = x - self.y = y - self.u = u - self.v = v - self.scale = scale - self.arrow_scale = arrow_scale - self.angle = angle - self.end_x = [] - self.end_y = [] - self.scale_uv() - barb_x, barb_y = self.get_barbs() - arrow_x, arrow_y = self.get_quiver_arrows() - - def scale_uv(self): - """ - Scales u and v to avoid overlap of the arrows. - - u and v are added to x and y to get the - endpoints of the arrows so a smaller scale value will - result in less overlap of arrows. - """ - self.u = [i * self.scale for i in self.u] - self.v = [i * self.scale for i in self.v] - - def get_barbs(self): - """ - Creates x and y startpoint and endpoint pairs - - After finding the endpoint of each barb this zips startpoint and - endpoint pairs to create 2 lists: x_values for barbs and y values - for barbs - - :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint - x_value pairs separated by a None to create the barb of the arrow, - and list of startpoint and endpoint y_value pairs separated by a - None to create the barb of the arrow. - """ - self.end_x = [i + j for i, j in zip(self.x, self.u)] - self.end_y = [i + j for i, j in zip(self.y, self.v)] - empty = [None] * len(self.x) - barb_x = FigureFactory._flatten(zip(self.x, self.end_x, empty)) - barb_y = FigureFactory._flatten(zip(self.y, self.end_y, empty)) - return barb_x, barb_y - - def get_quiver_arrows(self): - """ - Creates lists of x and y values to plot the arrows - - Gets length of each barb then calculates the length of each side of - the arrow. Gets angle of barb and applies angle to each side of the - arrowhead. Next uses arrow_scale to scale the length of arrowhead and - creates x and y values for arrowhead point1 and point2. Finally x and y - values for point1, endpoint and point2s for each arrowhead are - separated by a None and zipped to create lists of x and y values for - the arrows. - - :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2 - x_values separated by a None to create the arrowhead and list of - point1, endpoint, point2 y_values separated by a None to create - the barb of the arrow. - """ - dif_x = [i - j for i, j in zip(self.end_x, self.x)] - dif_y = [i - j for i, j in zip(self.end_y, self.y)] - - # Get barb lengths(default arrow length = 30% barb length) - barb_len = [None] * len(self.x) - for index in range(len(barb_len)): - barb_len[index] = math.hypot(dif_x[index], dif_y[index]) - - # Make arrow lengths - arrow_len = [None] * len(self.x) - arrow_len = [i * self.arrow_scale for i in barb_len] - - # Get barb angles - barb_ang = [None] * len(self.x) - for index in range(len(barb_ang)): - barb_ang[index] = math.atan2(dif_y[index], dif_x[index]) - - # Set angles to create arrow - ang1 = [i + self.angle for i in barb_ang] - ang2 = [i - self.angle for i in barb_ang] - - cos_ang1 = [None] * len(ang1) - for index in range(len(ang1)): - cos_ang1[index] = math.cos(ang1[index]) - seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)] - - sin_ang1 = [None] * len(ang1) - for index in range(len(ang1)): - sin_ang1[index] = math.sin(ang1[index]) - seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)] - - cos_ang2 = [None] * len(ang2) - for index in range(len(ang2)): - cos_ang2[index] = math.cos(ang2[index]) - seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)] - - sin_ang2 = [None] * len(ang2) - for index in range(len(ang2)): - sin_ang2[index] = math.sin(ang2[index]) - seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)] - - # Set coordinates to create arrow - for index in range(len(self.end_x)): - point1_x = [i - j for i, j in zip(self.end_x, seg1_x)] - point1_y = [i - j for i, j in zip(self.end_y, seg1_y)] - point2_x = [i - j for i, j in zip(self.end_x, seg2_x)] - point2_y = [i - j for i, j in zip(self.end_y, seg2_y)] - - # Combine lists to create arrow - empty = [None] * len(self.end_x) - arrow_x = FigureFactory._flatten(zip(point1_x, self.end_x, - point2_x, empty)) - arrow_y = FigureFactory._flatten(zip(point1_y, self.end_y, - point2_y, empty)) - return arrow_x, arrow_y - - -class _Streamline(FigureFactory): - """ - Refer to FigureFactory.create_streamline() for docstring - """ - def __init__(self, x, y, u, v, - density, angle, - arrow_scale, **kwargs): - self.x = np.array(x) - self.y = np.array(y) - self.u = np.array(u) - self.v = np.array(v) - self.angle = angle - self.arrow_scale = arrow_scale - self.density = int(30 * density) # Scale similarly to other functions - self.delta_x = self.x[1] - self.x[0] - self.delta_y = self.y[1] - self.y[0] - self.val_x = self.x - self.val_y = self.y - - # Set up spacing - self.blank = np.zeros((self.density, self.density)) - self.spacing_x = len(self.x) / float(self.density - 1) - self.spacing_y = len(self.y) / float(self.density - 1) - self.trajectories = [] - - # Rescale speed onto axes-coordinates - self.u = self.u / (self.x[-1] - self.x[0]) - self.v = self.v / (self.y[-1] - self.y[0]) - self.speed = np.sqrt(self.u ** 2 + self.v ** 2) - - # Rescale u and v for integrations. - self.u *= len(self.x) - self.v *= len(self.y) - self.st_x = [] - self.st_y = [] - self.get_streamlines() - streamline_x, streamline_y = self.sum_streamlines() - arrows_x, arrows_y = self.get_streamline_arrows() - - def blank_pos(self, xi, yi): - """ - Set up positions for trajectories to be used with rk4 function. - """ - return (int((xi / self.spacing_x) + 0.5), - int((yi / self.spacing_y) + 0.5)) - - def value_at(self, a, xi, yi): - """ - Set up for RK4 function, based on Bokeh's streamline code - """ - if isinstance(xi, np.ndarray): - self.x = xi.astype(np.int) - self.y = yi.astype(np.int) - else: - self.val_x = np.int(xi) - self.val_y = np.int(yi) - a00 = a[self.val_y, self.val_x] - a01 = a[self.val_y, self.val_x + 1] - a10 = a[self.val_y + 1, self.val_x] - a11 = a[self.val_y + 1, self.val_x + 1] - xt = xi - self.val_x - yt = yi - self.val_y - a0 = a00 * (1 - xt) + a01 * xt - a1 = a10 * (1 - xt) + a11 * xt - return a0 * (1 - yt) + a1 * yt - - def rk4_integrate(self, x0, y0): - """ - RK4 forward and back trajectories from the initial conditions. - - Adapted from Bokeh's streamline -uses Runge-Kutta method to fill - x and y trajectories then checks length of traj (s in units of axes) - """ - def f(xi, yi): - dt_ds = 1. / self.value_at(self.speed, xi, yi) - ui = self.value_at(self.u, xi, yi) - vi = self.value_at(self.v, xi, yi) - return ui * dt_ds, vi * dt_ds - - def g(xi, yi): - dt_ds = 1. / self.value_at(self.speed, xi, yi) - ui = self.value_at(self.u, xi, yi) - vi = self.value_at(self.v, xi, yi) - return -ui * dt_ds, -vi * dt_ds - - check = lambda xi, yi: (0 <= xi < len(self.x) - 1 and - 0 <= yi < len(self.y) - 1) - xb_changes = [] - yb_changes = [] - - def rk4(x0, y0, f): - ds = 0.01 - stotal = 0 - xi = x0 - yi = y0 - xb, yb = self.blank_pos(xi, yi) - xf_traj = [] - yf_traj = [] - while check(xi, yi): - xf_traj.append(xi) - yf_traj.append(yi) - try: - k1x, k1y = f(xi, yi) - k2x, k2y = f(xi + .5 * ds * k1x, yi + .5 * ds * k1y) - k3x, k3y = f(xi + .5 * ds * k2x, yi + .5 * ds * k2y) - k4x, k4y = f(xi + ds * k3x, yi + ds * k3y) - except IndexError: - break - xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6. - yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6. - if not check(xi, yi): - break - stotal += ds - new_xb, new_yb = self.blank_pos(xi, yi) - if new_xb != xb or new_yb != yb: - if self.blank[new_yb, new_xb] == 0: - self.blank[new_yb, new_xb] = 1 - xb_changes.append(new_xb) - yb_changes.append(new_yb) - xb = new_xb - yb = new_yb - else: - break - if stotal > 2: - break - return stotal, xf_traj, yf_traj - - sf, xf_traj, yf_traj = rk4(x0, y0, f) - sb, xb_traj, yb_traj = rk4(x0, y0, g) - stotal = sf + sb - x_traj = xb_traj[::-1] + xf_traj[1:] - y_traj = yb_traj[::-1] + yf_traj[1:] - - if len(x_traj) < 1: - return None - if stotal > .2: - initxb, inityb = self.blank_pos(x0, y0) - self.blank[inityb, initxb] = 1 - return x_traj, y_traj - else: - for xb, yb in zip(xb_changes, yb_changes): - self.blank[yb, xb] = 0 - return None - - def traj(self, xb, yb): - """ - Integrate trajectories - - :param (int) xb: results of passing xi through self.blank_pos - :param (int) xy: results of passing yi through self.blank_pos - - Calculate each trajectory based on rk4 integrate method. - """ - - if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density: - return - if self.blank[yb, xb] == 0: - t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y) - if t is not None: - self.trajectories.append(t) - - def get_streamlines(self): - """ - Get streamlines by building trajectory set. - """ - for indent in range(self.density // 2): - for xi in range(self.density - 2 * indent): - self.traj(xi + indent, indent) - self.traj(xi + indent, self.density - 1 - indent) - self.traj(indent, xi + indent) - self.traj(self.density - 1 - indent, xi + indent) - - self.st_x = [np.array(t[0]) * self.delta_x + self.x[0] for t in - self.trajectories] - self.st_y = [np.array(t[1]) * self.delta_y + self.y[0] for t in - self.trajectories] - - for index in range(len(self.st_x)): - self.st_x[index] = self.st_x[index].tolist() - self.st_x[index].append(np.nan) - - for index in range(len(self.st_y)): - self.st_y[index] = self.st_y[index].tolist() - self.st_y[index].append(np.nan) - - def get_streamline_arrows(self): - """ - Makes an arrow for each streamline. - - Gets angle of streamline at 1/3 mark and creates arrow coordinates - based off of user defined angle and arrow_scale. - - :param (array) st_x: x-values for all streamlines - :param (array) st_y: y-values for all streamlines - :param (angle in radians) angle: angle of arrowhead. Default = pi/9 - :param (float in [0,1]) arrow_scale: value to scale length of arrowhead - Default = .09 - :rtype (list, list) arrows_x: x-values to create arrowhead and - arrows_y: y-values to create arrowhead - """ - arrow_end_x = np.empty((len(self.st_x))) - arrow_end_y = np.empty((len(self.st_y))) - arrow_start_x = np.empty((len(self.st_x))) - arrow_start_y = np.empty((len(self.st_y))) - for index in range(len(self.st_x)): - arrow_end_x[index] = (self.st_x[index] - [int(len(self.st_x[index]) / 3)]) - arrow_start_x[index] = (self.st_x[index] - [(int(len(self.st_x[index]) / 3)) - 1]) - arrow_end_y[index] = (self.st_y[index] - [int(len(self.st_y[index]) / 3)]) - arrow_start_y[index] = (self.st_y[index] - [(int(len(self.st_y[index]) / 3)) - 1]) - - dif_x = arrow_end_x - arrow_start_x - dif_y = arrow_end_y - arrow_start_y - - streamline_ang = np.arctan(dif_y / dif_x) - - ang1 = streamline_ang + (self.angle) - ang2 = streamline_ang - (self.angle) - - seg1_x = np.cos(ang1) * self.arrow_scale - seg1_y = np.sin(ang1) * self.arrow_scale - seg2_x = np.cos(ang2) * self.arrow_scale - seg2_y = np.sin(ang2) * self.arrow_scale - - point1_x = np.empty((len(dif_x))) - point1_y = np.empty((len(dif_y))) - point2_x = np.empty((len(dif_x))) - point2_y = np.empty((len(dif_y))) - - for index in range(len(dif_x)): - if dif_x[index] >= 0: - point1_x[index] = arrow_end_x[index] - seg1_x[index] - point1_y[index] = arrow_end_y[index] - seg1_y[index] - point2_x[index] = arrow_end_x[index] - seg2_x[index] - point2_y[index] = arrow_end_y[index] - seg2_y[index] - else: - point1_x[index] = arrow_end_x[index] + seg1_x[index] - point1_y[index] = arrow_end_y[index] + seg1_y[index] - point2_x[index] = arrow_end_x[index] + seg2_x[index] - point2_y[index] = arrow_end_y[index] + seg2_y[index] - - space = np.empty((len(point1_x))) - space[:] = np.nan - - # Combine arrays into matrix - arrows_x = np.matrix([point1_x, arrow_end_x, point2_x, space]) - arrows_x = np.array(arrows_x) - arrows_x = arrows_x.flatten('F') - arrows_x = arrows_x.tolist() - - # Combine arrays into matrix - arrows_y = np.matrix([point1_y, arrow_end_y, point2_y, space]) - arrows_y = np.array(arrows_y) - arrows_y = arrows_y.flatten('F') - arrows_y = arrows_y.tolist() - - return arrows_x, arrows_y - - def sum_streamlines(self): - """ - Makes all streamlines readable as a single trace. - - :rtype (list, list): streamline_x: all x values for each streamline - combined into single list and streamline_y: all y values for each - streamline combined into single list - """ - streamline_x = sum(self.st_x, []) - streamline_y = sum(self.st_y, []) - return streamline_x, streamline_y - - -class _OHLC(FigureFactory): - """ - Refer to FigureFactory.create_ohlc_increase() for docstring. - """ - def __init__(self, open, high, low, close, dates, **kwargs): - self.open = open - self.high = high - self.low = low - self.close = close - self.empty = [None] * len(open) - self.dates = dates - - self.all_x = [] - self.all_y = [] - self.increase_x = [] - self.increase_y = [] - self.decrease_x = [] - self.decrease_y = [] - self.get_all_xy() - self.separate_increase_decrease() - - def get_all_xy(self): - """ - Zip data to create OHLC shape - - OHLC shape: low to high vertical bar with - horizontal branches for open and close values. - If dates were added, the smallest date difference is calculated and - multiplied by .2 to get the length of the open and close branches. - If no date data was provided, the x-axis is a list of integers and the - length of the open and close branches is .2. - """ - self.all_y = list(zip(self.open, self.open, self.high, - self.low, self.close, self.close, self.empty)) - if self.dates is not None: - date_dif = [] - for i in range(len(self.dates) - 1): - date_dif.append(self.dates[i + 1] - self.dates[i]) - date_dif_min = (min(date_dif)) / 5 - self.all_x = [[x - date_dif_min, x, x, x, x, x + - date_dif_min, None] for x in self.dates] - else: - self.all_x = [[x - .2, x, x, x, x, x + .2, None] - for x in range(len(self.open))] - - def separate_increase_decrease(self): - """ - Separate data into two groups: increase and decrease - - (1) Increase, where close > open and - (2) Decrease, where close <= open - """ - for index in range(len(self.open)): - if self.close[index] is None: - pass - elif self.close[index] > self.open[index]: - self.increase_x.append(self.all_x[index]) - self.increase_y.append(self.all_y[index]) - else: - self.decrease_x.append(self.all_x[index]) - self.decrease_y.append(self.all_y[index]) - - def get_increase(self): - """ - Flatten increase data and get increase text - - :rtype (list, list, list): flat_increase_x: x-values for the increasing - trace, flat_increase_y: y=values for the increasing trace and - text_increase: hovertext for the increasing trace - """ - flat_increase_x = FigureFactory._flatten(self.increase_x) - flat_increase_y = FigureFactory._flatten(self.increase_y) - text_increase = (("Open", "Open", "High", - "Low", "Close", "Close", '') - * (len(self.increase_x))) - - return flat_increase_x, flat_increase_y, text_increase - - def get_decrease(self): - """ - Flatten decrease data and get decrease text - - :rtype (list, list, list): flat_decrease_x: x-values for the decreasing - trace, flat_decrease_y: y=values for the decreasing trace and - text_decrease: hovertext for the decreasing trace - """ - flat_decrease_x = FigureFactory._flatten(self.decrease_x) - flat_decrease_y = FigureFactory._flatten(self.decrease_y) - text_decrease = (("Open", "Open", "High", - "Low", "Close", "Close", '') - * (len(self.decrease_x))) - - return flat_decrease_x, flat_decrease_y, text_decrease - - -class _Candlestick(FigureFactory): - """ - Refer to FigureFactory.create_candlestick() for docstring. - """ - def __init__(self, open, high, low, close, dates, **kwargs): - self.open = open - self.high = high - self.low = low - self.close = close - if dates is not None: - self.x = dates - else: - self.x = [x for x in range(len(self.open))] - self.get_candle_increase() - - def get_candle_increase(self): - """ - Separate increasing data from decreasing data. - - The data is increasing when close value > open value - and decreasing when the close value <= open value. - """ - increase_y = [] - increase_x = [] - for index in range(len(self.open)): - if self.close[index] > self.open[index]: - increase_y.append(self.low[index]) - increase_y.append(self.open[index]) - increase_y.append(self.close[index]) - increase_y.append(self.close[index]) - increase_y.append(self.close[index]) - increase_y.append(self.high[index]) - increase_x.append(self.x[index]) - - increase_x = [[x, x, x, x, x, x] for x in increase_x] - increase_x = FigureFactory._flatten(increase_x) - - return increase_x, increase_y - - def get_candle_decrease(self): - """ - Separate increasing data from decreasing data. - - The data is increasing when close value > open value - and decreasing when the close value <= open value. - """ - decrease_y = [] - decrease_x = [] - for index in range(len(self.open)): - if self.close[index] <= self.open[index]: - decrease_y.append(self.low[index]) - decrease_y.append(self.open[index]) - decrease_y.append(self.close[index]) - decrease_y.append(self.close[index]) - decrease_y.append(self.close[index]) - decrease_y.append(self.high[index]) - decrease_x.append(self.x[index]) - - decrease_x = [[x, x, x, x, x, x] for x in decrease_x] - decrease_x = FigureFactory._flatten(decrease_x) - - return decrease_x, decrease_y - - -class _Distplot(FigureFactory): - """ - Refer to TraceFactory.create_distplot() for docstring - """ - def __init__(self, hist_data, histnorm, group_labels, - bin_size, curve_type, colors, - rug_text, show_hist, show_curve): - self.hist_data = hist_data - self.histnorm = histnorm - self.group_labels = group_labels - self.bin_size = bin_size - self.show_hist = show_hist - self.show_curve = show_curve - self.trace_number = len(hist_data) - if rug_text: - self.rug_text = rug_text - else: - self.rug_text = [None] * self.trace_number - - self.start = [] - self.end = [] - if colors: - self.colors = colors - else: - self.colors = [ - "rgb(31, 119, 180)", "rgb(255, 127, 14)", - "rgb(44, 160, 44)", "rgb(214, 39, 40)", - "rgb(148, 103, 189)", "rgb(140, 86, 75)", - "rgb(227, 119, 194)", "rgb(127, 127, 127)", - "rgb(188, 189, 34)", "rgb(23, 190, 207)"] - self.curve_x = [None] * self.trace_number - self.curve_y = [None] * self.trace_number - - for trace in self.hist_data: - self.start.append(min(trace) * 1.) - self.end.append(max(trace) * 1.) - - def make_hist(self): - """ - Makes the histogram(s) for FigureFactory.create_distplot(). - - :rtype (list) hist: list of histogram representations - """ - hist = [None] * self.trace_number - - for index in range(self.trace_number): - hist[index] = dict(type='histogram', - x=self.hist_data[index], - xaxis='x1', - yaxis='y1', - histnorm=self.histnorm, - name=self.group_labels[index], - legendgroup=self.group_labels[index], - marker=dict(color=self.colors[index]), - autobinx=False, - xbins=dict(start=self.start[index], - end=self.end[index], - size=self.bin_size[index]), - opacity=.7) - return hist - - def make_kde(self): - """ - Makes the kernel density estimation(s) for create_distplot(). - - This is called when curve_type = 'kde' in create_distplot(). - - :rtype (list) curve: list of kde representations - """ - curve = [None] * self.trace_number - for index in range(self.trace_number): - self.curve_x[index] = [self.start[index] + - x * (self.end[index] - self.start[index]) - / 500 for x in range(500)] - self.curve_y[index] = (scipy.stats.gaussian_kde - (self.hist_data[index]) - (self.curve_x[index])) - - if self.histnorm == ALTERNATIVE_HISTNORM: - self.curve_y[index] *= self.bin_size[index] - - for index in range(self.trace_number): - curve[index] = dict(type='scatter', - x=self.curve_x[index], - y=self.curve_y[index], - xaxis='x1', - yaxis='y1', - mode='lines', - name=self.group_labels[index], - legendgroup=self.group_labels[index], - showlegend=False if self.show_hist else True, - marker=dict(color=self.colors[index])) - return curve - - def make_normal(self): - """ - Makes the normal curve(s) for create_distplot(). - - This is called when curve_type = 'normal' in create_distplot(). - - :rtype (list) curve: list of normal curve representations - """ - curve = [None] * self.trace_number - mean = [None] * self.trace_number - sd = [None] * self.trace_number - - for index in range(self.trace_number): - mean[index], sd[index] = (scipy.stats.norm.fit - (self.hist_data[index])) - self.curve_x[index] = [self.start[index] + - x * (self.end[index] - self.start[index]) - / 500 for x in range(500)] - self.curve_y[index] = scipy.stats.norm.pdf( - self.curve_x[index], loc=mean[index], scale=sd[index]) - - if self.histnorm == ALTERNATIVE_HISTNORM: - self.curve_y[index] *= self.bin_size[index] - - for index in range(self.trace_number): - curve[index] = dict(type='scatter', - x=self.curve_x[index], - y=self.curve_y[index], - xaxis='x1', - yaxis='y1', - mode='lines', - name=self.group_labels[index], - legendgroup=self.group_labels[index], - showlegend=False if self.show_hist else True, - marker=dict(color=self.colors[index])) - return curve - - def make_rug(self): - """ - Makes the rug plot(s) for create_distplot(). - - :rtype (list) rug: list of rug plot representations - """ - rug = [None] * self.trace_number - for index in range(self.trace_number): - - rug[index] = dict(type='scatter', - x=self.hist_data[index], - y=([self.group_labels[index]] * - len(self.hist_data[index])), - xaxis='x1', - yaxis='y2', - mode='markers', - name=self.group_labels[index], - legendgroup=self.group_labels[index], - showlegend=(False if self.show_hist or - self.show_curve else True), - text=self.rug_text[index], - marker=dict(color=self.colors[index], - symbol='line-ns-open')) - return rug - - -class _Dendrogram(FigureFactory): - """Refer to FigureFactory.create_dendrogram() for docstring.""" - - def __init__(self, X, orientation='bottom', labels=None, colorscale=None, - width="100%", height="100%", xaxis='xaxis', yaxis='yaxis', - distfun=None, linkagefun=lambda x: sch.linkage(x, 'complete')): - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - self.orientation = orientation - self.labels = labels - self.xaxis = xaxis - self.yaxis = yaxis - self.data = [] - self.leaves = [] - self.sign = {self.xaxis: 1, self.yaxis: 1} - self.layout = {self.xaxis: {}, self.yaxis: {}} - - if self.orientation in ['left', 'bottom']: - self.sign[self.xaxis] = 1 - else: - self.sign[self.xaxis] = -1 - - if self.orientation in ['right', 'bottom']: - self.sign[self.yaxis] = 1 - else: - self.sign[self.yaxis] = -1 - - if distfun is None: - distfun = scs.distance.pdist - - (dd_traces, xvals, yvals, - ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale, distfun, linkagefun) - - self.labels = ordered_labels - self.leaves = leaves - yvals_flat = yvals.flatten() - xvals_flat = xvals.flatten() - - self.zero_vals = [] - - for i in range(len(yvals_flat)): - if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals: - self.zero_vals.append(xvals_flat[i]) - - self.zero_vals.sort() - - self.layout = self.set_figure_layout(width, height) - self.data = graph_objs.Data(dd_traces) - - def get_color_dict(self, colorscale): - """ - Returns colorscale used for dendrogram tree clusters. - - :param (list) colorscale: Colors to use for the plot in rgb format. - :rtype (dict): A dict of default colors mapped to the user colorscale. - - """ - - # These are the color codes returned for dendrograms - # We're replacing them with nicer colors - d = {'r': 'red', - 'g': 'green', - 'b': 'blue', - 'c': 'cyan', - 'm': 'magenta', - 'y': 'yellow', - 'k': 'black', - 'w': 'white'} - default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0])) - - if colorscale is None: - colorscale = [ - 'rgb(0,116,217)', # blue - 'rgb(35,205,205)', # cyan - 'rgb(61,153,112)', # green - 'rgb(40,35,35)', # black - 'rgb(133,20,75)', # magenta - 'rgb(255,65,54)', # red - 'rgb(255,255,255)', # white - 'rgb(255,220,0)'] # yellow - - for i in range(len(default_colors.keys())): - k = list(default_colors.keys())[i] # PY3 won't index keys - if i < len(colorscale): - default_colors[k] = colorscale[i] - - return default_colors - - def set_axis_layout(self, axis_key): - """ - Sets and returns default axis object for dendrogram figure. - - :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc. - :rtype (dict): An axis_key dictionary with set parameters. - - """ - axis_defaults = { - 'type': 'linear', - 'ticks': 'outside', - 'mirror': 'allticks', - 'rangemode': 'tozero', - 'showticklabels': True, - 'zeroline': False, - 'showgrid': False, - 'showline': True, - } - - if len(self.labels) != 0: - axis_key_labels = self.xaxis - if self.orientation in ['left', 'right']: - axis_key_labels = self.yaxis - if axis_key_labels not in self.layout: - self.layout[axis_key_labels] = {} - self.layout[axis_key_labels]['tickvals'] = \ - [zv*self.sign[axis_key] for zv in self.zero_vals] - self.layout[axis_key_labels]['ticktext'] = self.labels - self.layout[axis_key_labels]['tickmode'] = 'array' - - self.layout[axis_key].update(axis_defaults) - - return self.layout[axis_key] - - def set_figure_layout(self, width, height): - """ - Sets and returns default layout object for dendrogram figure. - - """ - self.layout.update({ - 'showlegend': False, - 'autosize': False, - 'hovermode': 'closest', - 'width': width, - 'height': height - }) - - self.set_axis_layout(self.xaxis) - self.set_axis_layout(self.yaxis) - - return self.layout - - def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun): - """ - Calculates all the elements needed for plotting a dendrogram. - - :param (ndarray) X: Matrix of observations as array of arrays - :param (list) colorscale: Color scale for dendrogram tree clusters - :param (function) distfun: Function to compute the pairwise distance from the observations - :param (function) linkagefun: Function to compute the linkage matrix from the pairwise distances - :rtype (tuple): Contains all the traces in the following order: - (a) trace_list: List of Plotly trace objects for dendrogram tree - (b) icoord: All X points of the dendrogram tree as array of arrays - with length 4 - (c) dcoord: All Y points of the dendrogram tree as array of arrays - with length 4 - (d) ordered_labels: leaf labels in the order they are going to - appear on the plot - (e) P['leaves']: left-to-right traversal of the leaves - - """ - # TODO: protected until #282 - from plotly.graph_objs import graph_objs - d = distfun(X) - Z = linkagefun(d) - P = sch.dendrogram(Z, orientation=self.orientation, - labels=self.labels, no_plot=True) - - icoord = scp.array(P['icoord']) - dcoord = scp.array(P['dcoord']) - ordered_labels = scp.array(P['ivl']) - color_list = scp.array(P['color_list']) - colors = self.get_color_dict(colorscale) - - trace_list = [] - - for i in range(len(icoord)): - # xs and ys are arrays of 4 points that make up the '∩' shapes - # of the dendrogram tree - if self.orientation in ['top', 'bottom']: - xs = icoord[i] - else: - xs = dcoord[i] - - if self.orientation in ['top', 'bottom']: - ys = dcoord[i] - else: - ys = icoord[i] - color_key = color_list[i] - trace = graph_objs.Scatter( - x=np.multiply(self.sign[self.xaxis], xs), - y=np.multiply(self.sign[self.yaxis], ys), - mode='lines', - marker=graph_objs.Marker(color=colors[color_key]) - ) - - try: - x_index = int(self.xaxis[-1]) - except ValueError: - x_index = '' - - try: - y_index = int(self.yaxis[-1]) - except ValueError: - y_index = '' - - trace['xaxis'] = 'x' + x_index - trace['yaxis'] = 'y' + y_index - - trace_list.append(trace) - - return trace_list, icoord, dcoord, ordered_labels, P['leaves'] - - -class _AnnotatedHeatmap(FigureFactory): - """ - Refer to TraceFactory.create_annotated_heatmap() for docstring - """ - def __init__(self, z, x, y, annotation_text, colorscale, - font_colors, reversescale, **kwargs): - from plotly.graph_objs import graph_objs - - self.z = z - if x: - self.x = x - else: - self.x = range(len(z[0])) - if y: - self.y = y - else: - self.y = range(len(z)) - if annotation_text is not None: - self.annotation_text = annotation_text - else: - self.annotation_text = self.z - self.colorscale = colorscale - self.reversescale = reversescale - self.font_colors = font_colors - - def get_text_color(self): - """ - Get font color for annotations. - - The annotated heatmap can feature two text colors: min_text_color and - max_text_color. The min_text_color is applied to annotations for - heatmap values < (max_value - min_value)/2. The user can define these - two colors. Otherwise the colors are defined logically as black or - white depending on the heatmap's colorscale. - - :rtype (string, string) min_text_color, max_text_color: text - color for annotations for heatmap values < - (max_value - min_value)/2 and text color for annotations for - heatmap values >= (max_value - min_value)/2 - """ - # Plotly colorscales ranging from a lighter shade to a darker shade - colorscales = ['Greys', 'Greens', 'Blues', - 'YIGnBu', 'YIOrRd', 'RdBu', - 'Picnic', 'Jet', 'Hot', 'Blackbody', - 'Earth', 'Electric', 'Viridis'] - # Plotly colorscales ranging from a darker shade to a lighter shade - colorscales_reverse = ['Reds'] - if self.font_colors: - min_text_color = self.font_colors[0] - max_text_color = self.font_colors[-1] - elif self.colorscale in colorscales and self.reversescale: - min_text_color = '#000000' - max_text_color = '#FFFFFF' - elif self.colorscale in colorscales: - min_text_color = '#FFFFFF' - max_text_color = '#000000' - elif self.colorscale in colorscales_reverse and self.reversescale: - min_text_color = '#FFFFFF' - max_text_color = '#000000' - elif self.colorscale in colorscales_reverse: - min_text_color = '#000000' - max_text_color = '#FFFFFF' - elif isinstance(self.colorscale, list): - if 'rgb' in self.colorscale[0][1]: - min_col = map(int, - self.colorscale[0][1].strip('rgb()').split(',')) - max_col = map(int, - self.colorscale[-1][1].strip('rgb()').split(',')) - elif '#' in self.colorscale[0][1]: - min_col = FigureFactory._hex_to_rgb(self.colorscale[0][1]) - max_col = FigureFactory._hex_to_rgb(self.colorscale[-1][1]) - else: - min_col = [255, 255, 255] - max_col = [255, 255, 255] - - if (min_col[0]*0.299 + min_col[1]*0.587 + min_col[2]*0.114) > 186: - min_text_color = '#000000' - else: - min_text_color = '#FFFFFF' - if (max_col[0]*0.299 + max_col[1]*0.587 + max_col[2]*0.114) > 186: - max_text_color = '#000000' - else: - max_text_color = '#FFFFFF' - else: - min_text_color = '#000000' - max_text_color = '#000000' - return min_text_color, max_text_color - - def get_z_mid(self): - """ - Get the mid value of z matrix - - :rtype (float) z_avg: average val from z matrix - """ - if _numpy_imported and isinstance(self.z, np.ndarray): - z_min = np.amin(self.z) - z_max = np.amax(self.z) - else: - z_min = min(min(self.z)) - z_max = max(max(self.z)) - z_mid = (z_max+z_min) / 2 - return z_mid - - def make_annotations(self): - """ - Get annotations for each cell of the heatmap with graph_objs.Annotation - - :rtype (list[dict]) annotations: list of annotations for each cell of - the heatmap - """ - from plotly.graph_objs import graph_objs - min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self) - z_mid = _AnnotatedHeatmap.get_z_mid(self) - annotations = [] - for n, row in enumerate(self.z): - for m, val in enumerate(row): - font_color = min_text_color if val < z_mid else max_text_color - annotations.append( - graph_objs.Annotation( - text=str(self.annotation_text[n][m]), - x=self.x[m], - y=self.y[n], - xref='x1', - yref='y1', - font=dict(color=font_color), - showarrow=False)) - return annotations - - -class _Table(FigureFactory): - """ - Refer to TraceFactory.create_table() for docstring - """ - def __init__(self, table_text, colorscale, font_colors, index, - index_title, annotation_offset, **kwargs): - from plotly.graph_objs import graph_objs - if _pandas_imported and isinstance(table_text, pd.DataFrame): - headers = table_text.columns.tolist() - table_text_index = table_text.index.tolist() - table_text = table_text.values.tolist() - table_text.insert(0, headers) - if index: - table_text_index.insert(0, index_title) - for i in range(len(table_text)): - table_text[i].insert(0, table_text_index[i]) - self.table_text = table_text - self.colorscale = colorscale - self.font_colors = font_colors - self.index = index - self.annotation_offset = annotation_offset - self.x = range(len(table_text[0])) - self.y = range(len(table_text)) - - def get_table_matrix(self): - """ - Create z matrix to make heatmap with striped table coloring - - :rtype (list[list]) table_matrix: z matrix to make heatmap with striped - table coloring. - """ - header = [0] * len(self.table_text[0]) - odd_row = [.5] * len(self.table_text[0]) - even_row = [1] * len(self.table_text[0]) - table_matrix = [None] * len(self.table_text) - table_matrix[0] = header - for i in range(1, len(self.table_text), 2): - table_matrix[i] = odd_row - for i in range(2, len(self.table_text), 2): - table_matrix[i] = even_row - if self.index: - for array in table_matrix: - array[0] = 0 - return table_matrix - - def get_table_font_color(self): - """ - Fill font-color array. - - Table text color can vary by row so this extends a single color or - creates an array to set a header color and two alternating colors to - create the striped table pattern. - - :rtype (list[list]) all_font_colors: list of font colors for each row - in table. - """ - if len(self.font_colors) == 1: - all_font_colors = self.font_colors*len(self.table_text) - elif len(self.font_colors) == 3: - all_font_colors = list(range(len(self.table_text))) - all_font_colors[0] = self.font_colors[0] - for i in range(1, len(self.table_text), 2): - all_font_colors[i] = self.font_colors[1] - for i in range(2, len(self.table_text), 2): - all_font_colors[i] = self.font_colors[2] - elif len(self.font_colors) == len(self.table_text): - all_font_colors = self.font_colors - else: - all_font_colors = ['#000000']*len(self.table_text) - return all_font_colors - - def make_table_annotations(self): - """ - Generate annotations to fill in table text - - :rtype (list) annotations: list of annotations for each cell of the - table. - """ - from plotly.graph_objs import graph_objs - table_matrix = _Table.get_table_matrix(self) - all_font_colors = _Table.get_table_font_color(self) - annotations = [] - for n, row in enumerate(self.table_text): - for m, val in enumerate(row): - # Bold text in header and index - format_text = ('' + str(val) + '' if n == 0 or - self.index and m < 1 else str(val)) - # Match font color of index to font color of header - font_color = (self.font_colors[0] if self.index and m == 0 - else all_font_colors[n]) - annotations.append( - graph_objs.Annotation( - text=format_text, - x=self.x[m] - self.annotation_offset, - y=self.y[n], - xref='x1', - yref='y1', - align="left", - xanchor="left", - font=dict(color=font_color), - showarrow=False)) - return annotations + def create_violin(*args, **kwargs): + FigureFactory._deprecated('create_violin') + from plotly.figure_factory import create_violin + return create_violin(*args, **kwargs) diff --git a/plotly/utils.py b/plotly/utils.py index 782779c457c..37fdd3a26bb 100644 --- a/plotly/utils.py +++ b/plotly/utils.py @@ -7,7 +7,6 @@ """ from __future__ import absolute_import -import json import os.path import re import sys @@ -15,27 +14,16 @@ import decimal import pytz +from requests.compat import json as _json +from plotly.optional_imports import get_module from . exceptions import PlotlyError -try: - import numpy - _numpy_imported = True -except ImportError: - _numpy_imported = False - -try: - import pandas - _pandas_imported = True -except ImportError: - _pandas_imported = False - -try: - import sage.all - _sage_imported = True -except ImportError: - _sage_imported = False +# Optional imports, may be None for users that only use our core functionality. +numpy = get_module('numpy') +pandas = get_module('pandas') +sage_all = get_module('sage.all') ### incase people are using threading, we lock file reads @@ -51,7 +39,7 @@ def load_json_dict(filename, *args): lock.acquire() with open(filename, "r") as f: try: - data = json.load(f) + data = _json.load(f) if not isinstance(data, dict): data = {} except: @@ -66,7 +54,7 @@ def save_json_dict(filename, json_dict): """Save json to file. Error if path DNE, not a dict, or invalid json.""" if isinstance(json_dict, dict): # this will raise a TypeError if something goes wrong - json_string = json.dumps(json_dict, indent=4) + json_string = _json.dumps(json_dict, indent=4) lock.acquire() with open(filename, "w") as f: f.write(json_string) @@ -112,7 +100,7 @@ class NotEncodable(Exception): pass -class PlotlyJSONEncoder(json.JSONEncoder): +class PlotlyJSONEncoder(_json.JSONEncoder): """ Meant to be passed as the `cls` kwarg to json.dumps(obj, cls=..) @@ -149,7 +137,8 @@ def encode(self, o): # 1. `loads` to switch Infinity, -Infinity, NaN to None # 2. `dumps` again so you get 'null' instead of extended JSON try: - new_o = json.loads(encoded_o, parse_constant=self.coerce_to_strict) + new_o = _json.loads(encoded_o, + parse_constant=self.coerce_to_strict) except ValueError: # invalid separators will fail here. raise a helpful exception @@ -158,10 +147,10 @@ def encode(self, o): "valid JSON separators?" ) else: - return json.dumps(new_o, sort_keys=self.sort_keys, - indent=self.indent, - separators=(self.item_separator, - self.key_separator)) + return _json.dumps(new_o, sort_keys=self.sort_keys, + indent=self.indent, + separators=(self.item_separator, + self.key_separator)) def default(self, obj): """ @@ -210,7 +199,7 @@ def default(self, obj): return encoding_method(obj) except NotEncodable: pass - return json.JSONEncoder.default(self, obj) + return _json.JSONEncoder.default(self, obj) @staticmethod def encode_as_plotly(obj): @@ -231,12 +220,12 @@ def encode_as_list(obj): @staticmethod def encode_as_sage(obj): """Attempt to convert sage.all.RR to floats and sage.all.ZZ to ints""" - if not _sage_imported: + if not sage_all: raise NotEncodable - if obj in sage.all.RR: + if obj in sage_all.RR: return float(obj) - elif obj in sage.all.ZZ: + elif obj in sage_all.ZZ: return int(obj) else: raise NotEncodable @@ -244,7 +233,7 @@ def encode_as_sage(obj): @staticmethod def encode_as_pandas(obj): """Attempt to convert pandas.NaT""" - if not _pandas_imported: + if not pandas: raise NotEncodable if obj is pandas.NaT: @@ -255,7 +244,7 @@ def encode_as_pandas(obj): @staticmethod def encode_as_numpy(obj): """Attempt to convert numpy.ma.core.masked""" - if not _numpy_imported: + if not numpy: raise NotEncodable if obj is numpy.ma.core.masked: diff --git a/plotly/version.py b/plotly/version.py index 84c54b74824..afced14728f 100644 --- a/plotly/version.py +++ b/plotly/version.py @@ -1 +1 @@ -__version__ = '1.13.0' +__version__ = '2.0.0' diff --git a/plotly/widgets/graph_widget.py b/plotly/widgets/graph_widget.py index a359474a7d0..f7eb0a86084 100644 --- a/plotly/widgets/graph_widget.py +++ b/plotly/widgets/graph_widget.py @@ -2,11 +2,11 @@ Module to allow Plotly graphs to interact with IPython widgets. """ -import json import uuid from collections import deque from pkg_resources import resource_string +from requests.compat import json as _json # TODO: protected imports? from IPython.html import widgets @@ -93,7 +93,7 @@ def _handle_msg(self, message): while self._clientMessages: _message = self._clientMessages.popleft() _message['graphId'] = self._graphId - _message = json.dumps(_message) + _message = _json.dumps(_message) self._message = _message if content.get('event', '') in ['click', 'hover', 'zoom']: @@ -131,7 +131,7 @@ def _handle_outgoing_message(self, message): else: message['graphId'] = self._graphId message['uid'] = str(uuid.uuid4()) - self._message = json.dumps(message, cls=utils.PlotlyJSONEncoder) + self._message = _json.dumps(message, cls=utils.PlotlyJSONEncoder) def on_click(self, callback, remove=False): """ Assign a callback to click events propagated diff --git a/setup.py b/setup.py index 9c50bcbf475..1e66057b516 100644 --- a/setup.py +++ b/setup.py @@ -31,8 +31,12 @@ def readme(): ], license='MIT', packages=['plotly', + 'plotly/api', + 'plotly/api/v1', + 'plotly/api/v2', 'plotly/plotly', 'plotly/plotly/chunked_requests', + 'plotly/figure_factory', 'plotly/graph_objs', 'plotly/grid_objs', 'plotly/widgets',