diff --git a/CHANGELOG.md b/CHANGELOG.md
index 39d33089967..564cc774fa1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,26 @@
All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](http://semver.org/).
+## [2.0.0]
+
+### Changed
+- `plotly.exceptions.PlotlyRequestException` is *always* raised for network
+failures. Previously either a `PlotlyError`, `PlotlyRequestException`, or a
+`requests.exceptions.ReqestException` could be raised. In particular, scripts
+which depend on `try-except` blocks containing network requests should be
+revisited.
+- `plotly.py:sign_in` now validates to the plotly server specified in your
+ config. If it cannot make a successful request, it raises a `PlotlyError`.
+- `plotly.figure_factory` will raise an `ImportError` if `numpy` is not
+ installed.
+
+### Deprecated
+- `plotly.tools.FigureFactory`. Use `plotly.figure_factory.*`.
+- (optional imports) `plotly.tools._*_imported` It was private anyhow, but now
+it's gone. (e.g., `_numpy_imported`)
+- (plotly v2 helper) `plotly.py._api_v2` It was private anyhow, but now it's
+gone.
+
## [1.13.0] - 2016-01-17
### Added
- Python 3.5 has been added as a tested environment for this package.
diff --git a/circle.yml b/circle.yml
index 62a7bdb505a..29fb8d4fa11 100644
--- a/circle.yml
+++ b/circle.yml
@@ -52,5 +52,9 @@ test:
- sudo chmod -R 444 ${PLOTLY_CONFIG_DIR} && python -c "import plotly"
# test that giving back write permissions works again
- # this also has to pass the test suite that follows
- sudo chmod -R 777 ${PLOTLY_CONFIG_DIR} && python -c "import plotly"
+
+ # test that figure_factory cannot be imported with only core requirements.
+ # since optional requirements is part of the test suite, we don't need to
+ # worry about testing that it *can* be imported in this case.
+ - $(! python -c "import plotly.figure_factory")
diff --git a/optional-requirements.txt b/optional-requirements.txt
index 9dcc04023b1..ada6fb8598d 100644
--- a/optional-requirements.txt
+++ b/optional-requirements.txt
@@ -12,8 +12,9 @@ numpy
# matplotlib==1.3.1
## testing dependencies ##
-nose
-coverage
+coverage==4.3.1
+mock==2.0.0
+nose==1.3.3
## ipython ##
ipython
diff --git a/plotly/api/__init__.py b/plotly/api/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/plotly/api/utils.py b/plotly/api/utils.py
new file mode 100644
index 00000000000..d9d1d21f504
--- /dev/null
+++ b/plotly/api/utils.py
@@ -0,0 +1,41 @@
+from base64 import b64encode
+
+from requests.compat import builtin_str, is_py2
+
+
+def _to_native_string(string, encoding):
+ if isinstance(string, builtin_str):
+ return string
+ if is_py2:
+ return string.encode(encoding)
+ return string.decode(encoding)
+
+
+def to_native_utf8_string(string):
+ return _to_native_string(string, 'utf-8')
+
+
+def to_native_ascii_string(string):
+ return _to_native_string(string, 'ascii')
+
+
+def basic_auth(username, password):
+ """
+ Creates the basic auth value to be used in an authorization header.
+
+ This is mostly copied from the requests library.
+
+ :param (str) username: A Plotly username.
+ :param (str) password: The password for the given Plotly username.
+ :returns: (str) An 'authorization' header for use in a request header.
+
+ """
+ if isinstance(username, str):
+ username = username.encode('latin1')
+
+ if isinstance(password, str):
+ password = password.encode('latin1')
+
+ return 'Basic ' + to_native_ascii_string(
+ b64encode(b':'.join((username, password))).strip()
+ )
diff --git a/plotly/api/v1/__init__.py b/plotly/api/v1/__init__.py
new file mode 100644
index 00000000000..a43ff61f4c8
--- /dev/null
+++ b/plotly/api/v1/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+
+from plotly.api.v1.clientresp import clientresp
diff --git a/plotly/api/v1/clientresp.py b/plotly/api/v1/clientresp.py
new file mode 100644
index 00000000000..c3af66c6b1c
--- /dev/null
+++ b/plotly/api/v1/clientresp.py
@@ -0,0 +1,44 @@
+"""Interface to deprecated /clientresp API. Subject to deletion."""
+from __future__ import absolute_import
+
+import warnings
+
+from requests.compat import json as _json
+
+from plotly import config, utils, version
+from plotly.api.v1.utils import request
+
+
+def clientresp(data, **kwargs):
+ """
+ Deprecated endpoint, still used because it can parse data out of a plot.
+
+ When we get around to forcing users to create grids and then create plots,
+ we can finally get rid of this.
+
+ :param (list) data: The data array from a figure.
+
+ """
+ creds = config.get_credentials()
+ cfg = config.get_config()
+
+ dumps_kwargs = {'sort_keys': True, 'cls': utils.PlotlyJSONEncoder}
+
+ payload = {
+ 'platform': 'python', 'version': version.__version__,
+ 'args': _json.dumps(data, **dumps_kwargs),
+ 'un': creds['username'], 'key': creds['api_key'], 'origin': 'plot',
+ 'kwargs': _json.dumps(kwargs, **dumps_kwargs)
+ }
+
+ url = '{plotly_domain}/clientresp'.format(**cfg)
+ response = request('post', url, data=payload)
+
+ # Old functionality, just keeping it around.
+ parsed_content = response.json()
+ if parsed_content.get('warning'):
+ warnings.warn(parsed_content['warning'])
+ if parsed_content.get('message'):
+ print(parsed_content['message'])
+
+ return response
diff --git a/plotly/api/v1/utils.py b/plotly/api/v1/utils.py
new file mode 100644
index 00000000000..abfdf745c3e
--- /dev/null
+++ b/plotly/api/v1/utils.py
@@ -0,0 +1,87 @@
+from __future__ import absolute_import
+
+import requests
+from requests.exceptions import RequestException
+
+from plotly import config, exceptions
+from plotly.api.utils import basic_auth
+
+
+def validate_response(response):
+ """
+ Raise a helpful PlotlyRequestError for failed requests.
+
+ :param (requests.Response) response: A Response object from an api request.
+ :raises: (PlotlyRequestError) If the request failed for any reason.
+ :returns: (None)
+
+ """
+ content = response.content
+ status_code = response.status_code
+ try:
+ parsed_content = response.json()
+ except ValueError:
+ message = content if content else 'No Content'
+ raise exceptions.PlotlyRequestError(message, status_code, content)
+
+ message = ''
+ if isinstance(parsed_content, dict):
+ error = parsed_content.get('error')
+ if error:
+ message = error
+ else:
+ if response.ok:
+ return
+ if not message:
+ message = content if content else 'No Content'
+
+ raise exceptions.PlotlyRequestError(message, status_code, content)
+
+
+def get_headers():
+ """
+ Using session credentials/config, get headers for a v1 API request.
+
+ Users may have their own proxy layer and so we free up the `authorization`
+ header for this purpose (instead adding the user authorization in a new
+ `plotly-authorization` header). See pull #239.
+
+ :returns: (dict) Headers to add to a requests.request call.
+
+ """
+ headers = {}
+ creds = config.get_credentials()
+ proxy_auth = basic_auth(creds['proxy_username'], creds['proxy_password'])
+
+ if config.get_config()['plotly_proxy_authorization']:
+ headers['authorization'] = proxy_auth
+
+ return headers
+
+
+def request(method, url, **kwargs):
+ """
+ Central place to make any v1 api request.
+
+ :param (str) method: The request method ('get', 'put', 'delete', ...).
+ :param (str) url: The full api url to make the request to.
+ :param kwargs: These are passed along to requests.
+ :return: (requests.Response) The response directly from requests.
+
+ """
+ if kwargs.get('json', None) is not None:
+ # See plotly.api.v2.utils.request for examples on how to do this.
+ raise exceptions.PlotlyError('V1 API does not handle arbitrary json.')
+ kwargs['headers'] = dict(kwargs.get('headers', {}), **get_headers())
+ kwargs['verify'] = config.get_config()['plotly_ssl_verification']
+ try:
+ response = requests.request(method, url, **kwargs)
+ except RequestException as e:
+ # The message can be an exception. E.g., MaxRetryError.
+ message = str(getattr(e, 'message', 'No message'))
+ response = getattr(e, 'response', None)
+ status_code = response.status_code if response else None
+ content = response.content if response else 'No content'
+ raise exceptions.PlotlyRequestError(message, status_code, content)
+ validate_response(response)
+ return response
diff --git a/plotly/api/v2/__init__.py b/plotly/api/v2/__init__.py
new file mode 100644
index 00000000000..8424927d1c6
--- /dev/null
+++ b/plotly/api/v2/__init__.py
@@ -0,0 +1,4 @@
+from __future__ import absolute_import
+
+from plotly.api.v2 import (files, folders, grids, images, plot_schema, plots,
+ users)
diff --git a/plotly/api/v2/files.py b/plotly/api/v2/files.py
new file mode 100644
index 00000000000..650ab48fc85
--- /dev/null
+++ b/plotly/api/v2/files.py
@@ -0,0 +1,85 @@
+"""Interface to Plotly's /v2/files endpoints."""
+from __future__ import absolute_import
+
+from plotly.api.v2.utils import build_url, make_params, request
+
+RESOURCE = 'files'
+
+
+def retrieve(fid, share_key=None):
+ """
+ Retrieve a general file from Plotly.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) share_key: The secret key granting 'read' access if private.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid)
+ params = make_params(share_key=share_key)
+ return request('get', url, params=params)
+
+
+def update(fid, body):
+ """
+ Update a general file from Plotly.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid)
+ return request('put', url, json=body)
+
+
+def trash(fid):
+ """
+ Soft-delete a general file from Plotly. (Can be undone with 'restore').
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='trash')
+ return request('post', url)
+
+
+def restore(fid):
+ """
+ Restore a trashed, general file from Plotly. See 'trash'.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='restore')
+ return request('post', url)
+
+
+def permanent_delete(fid):
+ """
+ Permanently delete a trashed, general file from Plotly. See 'trash'.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='permanent_delete')
+ return request('delete', url)
+
+
+def lookup(path, parent=None, user=None, exists=None):
+ """
+ Retrieve a general file from Plotly without needing a fid.
+
+ :param (str) path: The '/'-delimited path specifying the file location.
+ :param (int) parent: Parent id, an integer, which the path is relative to.
+ :param (str) user: The username to target files for. Defaults to requestor.
+ :param (bool) exists: If True, don't return the full file, just a flag.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, route='lookup')
+ params = make_params(path=path, parent=parent, user=user, exists=exists)
+ return request('get', url, params=params)
diff --git a/plotly/api/v2/folders.py b/plotly/api/v2/folders.py
new file mode 100644
index 00000000000..2dcf84670e7
--- /dev/null
+++ b/plotly/api/v2/folders.py
@@ -0,0 +1,103 @@
+"""Interface to Plotly's /v2/folders endpoints."""
+from __future__ import absolute_import
+
+from plotly.api.v2.utils import build_url, make_params, request
+
+RESOURCE = 'folders'
+
+
+def create(body):
+ """
+ Create a new folder.
+
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE)
+ return request('post', url, json=body)
+
+
+def retrieve(fid, share_key=None):
+ """
+ Retrieve a folder from Plotly.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) share_key: The secret key granting 'read' access if private.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid)
+ params = make_params(share_key=share_key)
+ return request('get', url, params=params)
+
+
+def update(fid, body):
+ """
+ Update a folder from Plotly.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid)
+ return request('put', url, json=body)
+
+
+def trash(fid):
+ """
+ Soft-delete a folder from Plotly. (Can be undone with 'restore').
+
+ This action is recursively done on files inside the folder.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='trash')
+ return request('post', url)
+
+
+def restore(fid):
+ """
+ Restore a trashed folder from Plotly. See 'trash'.
+
+ This action is recursively done on files inside the folder.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='restore')
+ return request('post', url)
+
+
+def permanent_delete(fid):
+ """
+ Permanently delete a trashed folder file from Plotly. See 'trash'.
+
+ This action is recursively done on files inside the folder.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='permanent_delete')
+ return request('delete', url)
+
+
+def lookup(path, parent=None, user=None, exists=None):
+ """
+ Retrieve a folder file from Plotly without needing a fid.
+
+ :param (str) path: The '/'-delimited path specifying the file location.
+ :param (int) parent: Parent id, an integer, which the path is relative to.
+ :param (str) user: The username to target files for. Defaults to requestor.
+ :param (bool) exists: If True, don't return the full file, just a flag.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, route='lookup')
+ params = make_params(path=path, parent=parent, user=user, exists=exists)
+ return request('get', url, params=params)
diff --git a/plotly/api/v2/grids.py b/plotly/api/v2/grids.py
new file mode 100644
index 00000000000..144ec3bd23f
--- /dev/null
+++ b/plotly/api/v2/grids.py
@@ -0,0 +1,180 @@
+"""Interface to Plotly's /v2/grids endpoints."""
+from __future__ import absolute_import
+
+from plotly.api.v2.utils import build_url, make_params, request
+
+RESOURCE = 'grids'
+
+
+def create(body):
+ """
+ Create a new grid.
+
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE)
+ return request('post', url, json=body)
+
+
+def retrieve(fid, share_key=None):
+ """
+ Retrieve a grid from Plotly.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) share_key: The secret key granting 'read' access if private.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid)
+ params = make_params(share_key=share_key)
+ return request('get', url, params=params)
+
+
+def content(fid, share_key=None):
+ """
+ Retrieve full content for the grid (normal retrieve only yields preview)
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) share_key: The secret key granting 'read' access if private.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='content')
+ params = make_params(share_key=share_key)
+ return request('get', url, params=params)
+
+
+def update(fid, body):
+ """
+ Update a grid from Plotly.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid)
+ return request('put', url, json=body)
+
+
+def trash(fid):
+ """
+ Soft-delete a grid from Plotly. (Can be undone with 'restore').
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='trash')
+ return request('post', url)
+
+
+def restore(fid):
+ """
+ Restore a trashed grid from Plotly. See 'trash'.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='restore')
+ return request('post', url)
+
+
+def permanent_delete(fid):
+ """
+ Permanently delete a trashed grid file from Plotly. See 'trash'.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='permanent_delete')
+ return request('delete', url)
+
+
+def lookup(path, parent=None, user=None, exists=None):
+ """
+ Retrieve a grid file from Plotly without needing a fid.
+
+ :param (str) path: The '/'-delimited path specifying the file location.
+ :param (int) parent: Parent id, an integer, which the path is relative to.
+ :param (str) user: The username to target files for. Defaults to requestor.
+ :param (bool) exists: If True, don't return the full file, just a flag.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, route='lookup')
+ params = make_params(path=path, parent=parent, user=user, exists=exists)
+ return request('get', url, params=params)
+
+
+def col_create(fid, body):
+ """
+ Create a new column (or columns) inside a grid.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='col')
+ return request('post', url, json=body)
+
+
+def col_retrieve(fid, uid):
+ """
+ Retrieve a column (or columns) from a grid.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) uid: A ','-concatenated string of column uids in the grid.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='col')
+ params = make_params(uid=uid)
+ return request('get', url, params=params)
+
+
+def col_update(fid, uid, body):
+ """
+ Update a column (or columns) from a grid.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) uid: A ','-concatenated string of column uids in the grid.
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='col')
+ params = make_params(uid=uid)
+ return request('put', url, json=body, params=params)
+
+
+def col_delete(fid, uid):
+ """
+ Permanently delete a column (or columns) from a grid.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) uid: A ','-concatenated string of column uids in the grid.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='col')
+ params = make_params(uid=uid)
+ return request('delete', url, params=params)
+
+
+def row(fid, body):
+ """
+ Append rows to a grid.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='row')
+ return request('post', url, json=body)
diff --git a/plotly/api/v2/images.py b/plotly/api/v2/images.py
new file mode 100644
index 00000000000..4c9d1816081
--- /dev/null
+++ b/plotly/api/v2/images.py
@@ -0,0 +1,18 @@
+"""Interface to Plotly's /v2/images endpoints."""
+from __future__ import absolute_import
+
+from plotly.api.v2.utils import build_url, request
+
+RESOURCE = 'images'
+
+
+def create(body):
+ """
+ Generate an image (which does not get saved on Plotly).
+
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE)
+ return request('post', url, json=body)
diff --git a/plotly/api/v2/plot_schema.py b/plotly/api/v2/plot_schema.py
new file mode 100644
index 00000000000..4edbc0a707b
--- /dev/null
+++ b/plotly/api/v2/plot_schema.py
@@ -0,0 +1,19 @@
+"""Interface to Plotly's /v2/plot-schema endpoints."""
+from __future__ import absolute_import
+
+from plotly.api.v2.utils import build_url, make_params, request
+
+RESOURCE = 'plot-schema'
+
+
+def retrieve(sha1, **kwargs):
+ """
+ Retrieve the most up-to-date copy of the plot-schema wrt the given hash.
+
+ :param (str) sha1: The last-known hash of the plot-schema (or '').
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE)
+ params = make_params(sha1=sha1)
+ return request('get', url, params=params, **kwargs)
diff --git a/plotly/api/v2/plots.py b/plotly/api/v2/plots.py
new file mode 100644
index 00000000000..da9f2d9e395
--- /dev/null
+++ b/plotly/api/v2/plots.py
@@ -0,0 +1,119 @@
+"""Interface to Plotly's /v2/plots endpoints."""
+from __future__ import absolute_import
+
+from plotly.api.v2.utils import build_url, make_params, request
+
+RESOURCE = 'plots'
+
+
+def create(body):
+ """
+ Create a new plot.
+
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE)
+ return request('post', url, json=body)
+
+
+def retrieve(fid, share_key=None):
+ """
+ Retrieve a plot from Plotly.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) share_key: The secret key granting 'read' access if private.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid)
+ params = make_params(share_key=share_key)
+ return request('get', url, params=params)
+
+
+def content(fid, share_key=None, inline_data=None, map_data=None):
+ """
+ Retrieve the *figure* for a Plotly plot file.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (str) share_key: The secret key granting 'read' access if private.
+ :param (bool) inline_data: If True, include the data arrays with the plot.
+ :param (str) map_data: Currently only accepts 'unreadable' to return a
+ mapping of grid-fid: grid. This is useful if you
+ want to maintain structure between the plot and
+ referenced grids when you have READ access to the
+ plot, but you don't have READ access to the
+ underlying grids.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='content')
+ params = make_params(share_key=share_key, inline_data=inline_data,
+ map_data=map_data)
+ return request('get', url, params=params)
+
+
+def update(fid, body):
+ """
+ Update a plot from Plotly.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :param (dict) body: A mapping of body param names to values.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid)
+ return request('put', url, json=body)
+
+
+def trash(fid):
+ """
+ Soft-delete a plot from Plotly. (Can be undone with 'restore').
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='trash')
+ return request('post', url)
+
+
+def restore(fid):
+ """
+ Restore a trashed plot from Plotly. See 'trash'.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='restore')
+ return request('post', url)
+
+
+def permanent_delete(fid, params=None):
+ """
+ Permanently delete a trashed plot file from Plotly. See 'trash'.
+
+ :param (str) fid: The `{username}:{idlocal}` identifier. E.g. `foo:88`.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, id=fid, route='permanent_delete')
+ return request('delete', url, params=params)
+
+
+def lookup(path, parent=None, user=None, exists=None):
+ """
+ Retrieve a plot file from Plotly without needing a fid.
+
+ :param (str) path: The '/'-delimited path specifying the file location.
+ :param (int) parent: Parent id, an integer, which the path is relative to.
+ :param (str) user: The username to target files for. Defaults to requestor.
+ :param (bool) exists: If True, don't return the full file, just a flag.
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, route='lookup')
+ params = make_params(path=path, parent=parent, user=user, exists=exists)
+ return request('get', url, params=params)
diff --git a/plotly/api/v2/users.py b/plotly/api/v2/users.py
new file mode 100644
index 00000000000..cdfaf51c488
--- /dev/null
+++ b/plotly/api/v2/users.py
@@ -0,0 +1,17 @@
+"""Interface to Plotly's /v2/files endpoints."""
+from __future__ import absolute_import
+
+from plotly.api.v2.utils import build_url, request
+
+RESOURCE = 'users'
+
+
+def current():
+ """
+ Retrieve information on the logged-in user from Plotly.
+
+ :returns: (requests.Response) Returns response directly from requests.
+
+ """
+ url = build_url(RESOURCE, route='current')
+ return request('get', url)
diff --git a/plotly/api/v2/utils.py b/plotly/api/v2/utils.py
new file mode 100644
index 00000000000..21bd0ddd016
--- /dev/null
+++ b/plotly/api/v2/utils.py
@@ -0,0 +1,154 @@
+from __future__ import absolute_import
+
+import requests
+from requests.compat import json as _json
+from requests.exceptions import RequestException
+
+from plotly import config, exceptions, version, utils
+from plotly.api.utils import basic_auth
+
+
+def make_params(**kwargs):
+ """
+ Helper to create a params dict, skipping undefined entries.
+
+ :returns: (dict) A params dict to pass to `request`.
+
+ """
+ return {k: v for k, v in kwargs.items() if v is not None}
+
+
+def build_url(resource, id='', route=''):
+ """
+ Create a url for a request on a V2 resource.
+
+ :param (str) resource: E.g., 'files', 'plots', 'grids', etc.
+ :param (str) id: The unique identifier for the resource.
+ :param (str) route: Detail/list route. E.g., 'restore', 'lookup', etc.
+ :return: (str) The url.
+
+ """
+ base = config.get_config()['plotly_api_domain']
+ formatter = {'base': base, 'resource': resource, 'id': id, 'route': route}
+
+ # Add path to base url depending on the input params. Note that `route`
+ # can refer to a 'list' or a 'detail' route. Since it cannot refer to
+ # both at the same time, it's overloaded in this function.
+ if id:
+ if route:
+ url = '{base}/v2/{resource}/{id}/{route}'.format(**formatter)
+ else:
+ url = '{base}/v2/{resource}/{id}'.format(**formatter)
+ else:
+ if route:
+ url = '{base}/v2/{resource}/{route}'.format(**formatter)
+ else:
+ url = '{base}/v2/{resource}'.format(**formatter)
+
+ return url
+
+
+def validate_response(response):
+ """
+ Raise a helpful PlotlyRequestError for failed requests.
+
+ :param (requests.Response) response: A Response object from an api request.
+ :raises: (PlotlyRequestError) If the request failed for any reason.
+ :returns: (None)
+
+ """
+ if response.ok:
+ return
+
+ content = response.content
+ status_code = response.status_code
+ try:
+ parsed_content = response.json()
+ except ValueError:
+ message = content if content else 'No Content'
+ raise exceptions.PlotlyRequestError(message, status_code, content)
+
+ message = ''
+ if isinstance(parsed_content, dict):
+ errors = parsed_content.get('errors', [])
+ messages = [error.get('message') for error in errors]
+ message = '\n'.join([msg for msg in messages if msg])
+ if not message:
+ message = content if content else 'No Content'
+
+ raise exceptions.PlotlyRequestError(message, status_code, content)
+
+
+def get_headers():
+ """
+ Using session credentials/config, get headers for a V2 API request.
+
+ Users may have their own proxy layer and so we free up the `authorization`
+ header for this purpose (instead adding the user authorization in a new
+ `plotly-authorization` header). See pull #239.
+
+ :returns: (dict) Headers to add to a requests.request call.
+
+ """
+ creds = config.get_credentials()
+
+ headers = {
+ 'plotly-client-platform': 'python {}'.format(version.__version__),
+ 'content-type': 'application/json'
+ }
+
+ plotly_auth = basic_auth(creds['username'], creds['api_key'])
+ proxy_auth = basic_auth(creds['proxy_username'], creds['proxy_password'])
+
+ if config.get_config()['plotly_proxy_authorization']:
+ headers['authorization'] = proxy_auth
+ if creds['username'] and creds['api_key']:
+ headers['plotly-authorization'] = plotly_auth
+ else:
+ if creds['username'] and creds['api_key']:
+ headers['authorization'] = plotly_auth
+
+ return headers
+
+
+def request(method, url, **kwargs):
+ """
+ Central place to make any api v2 api request.
+
+ :param (str) method: The request method ('get', 'put', 'delete', ...).
+ :param (str) url: The full api url to make the request to.
+ :param kwargs: These are passed along (but possibly mutated) to requests.
+ :return: (requests.Response) The response directly from requests.
+
+ """
+ kwargs['headers'] = dict(kwargs.get('headers', {}), **get_headers())
+
+ # Change boolean params to lowercase strings. E.g., `True` --> `'true'`.
+ # Just change the value so that requests handles query string creation.
+ if isinstance(kwargs.get('params'), dict):
+ kwargs['params'] = kwargs['params'].copy()
+ for key in kwargs['params']:
+ if isinstance(kwargs['params'][key], bool):
+ kwargs['params'][key] = _json.dumps(kwargs['params'][key])
+
+ # We have a special json encoding class for non-native objects.
+ if kwargs.get('json') is not None:
+ if kwargs.get('data'):
+ raise exceptions.PlotlyError('Cannot supply data and json kwargs.')
+ kwargs['data'] = _json.dumps(kwargs.pop('json'), sort_keys=True,
+ cls=utils.PlotlyJSONEncoder)
+
+ # The config file determines whether reuqests should *verify*.
+ kwargs['verify'] = config.get_config()['plotly_ssl_verification']
+
+ try:
+ response = requests.request(method, url, **kwargs)
+ except RequestException as e:
+ # The message can be an exception. E.g., MaxRetryError.
+ message = str(getattr(e, 'message', 'No message'))
+ response = getattr(e, 'response', None)
+ status_code = response.status_code if response else None
+ content = response.content if response else 'No content'
+ raise exceptions.PlotlyRequestError(message, status_code, content)
+ validate_response(response)
+ return response
diff --git a/plotly/config.py b/plotly/config.py
new file mode 100644
index 00000000000..dc1b8e28654
--- /dev/null
+++ b/plotly/config.py
@@ -0,0 +1,35 @@
+"""
+Merges and prioritizes file/session config and credentials.
+
+This is promoted to its own module to simplify imports.
+
+"""
+from __future__ import absolute_import
+
+from plotly import session, tools
+
+
+def get_credentials():
+ """Returns the credentials that will be sent to plotly."""
+ credentials = tools.get_credentials_file()
+ session_credentials = session.get_session_credentials()
+ for credentials_key in credentials:
+
+ # checking for not false, but truthy value here is the desired behavior
+ session_value = session_credentials.get(credentials_key)
+ if session_value is False or session_value:
+ credentials[credentials_key] = session_value
+ return credentials
+
+
+def get_config():
+ """Returns either module config or file config."""
+ config = tools.get_config_file()
+ session_config = session.get_session_config()
+ for config_key in config:
+
+ # checking for not false, but truthy value here is the desired behavior
+ session_value = session_config.get(config_key)
+ if session_value is False or session_value:
+ config[config_key] = session_value
+ return config
diff --git a/plotly/exceptions.py b/plotly/exceptions.py
index 8f7c8920454..05df864497f 100644
--- a/plotly/exceptions.py
+++ b/plotly/exceptions.py
@@ -5,7 +5,9 @@
A module that contains plotly's exception hierarchy.
"""
-import json
+from __future__ import absolute_import
+
+from plotly.api.utils import to_native_utf8_string
# Base Plotly Error
@@ -18,29 +20,12 @@ class InputError(PlotlyError):
class PlotlyRequestError(PlotlyError):
- def __init__(self, requests_exception):
- self.status_code = requests_exception.response.status_code
- self.HTTPError = requests_exception
- content_type = requests_exception.response.headers['content-type']
- if 'json' in content_type:
- content = requests_exception.response.content
- if content != '':
- res_payload = json.loads(
- requests_exception.response.content.decode('utf8')
- )
- if 'detail' in res_payload:
- self.message = res_payload['detail']
- else:
- self.message = ''
- else:
- self.message = ''
- elif content_type == 'text/plain':
- self.message = requests_exception.response.content
- else:
- try:
- self.message = requests_exception.message
- except AttributeError:
- self.message = 'unknown error'
+ """General API error. Raised for *all* failed requests."""
+
+ def __init__(self, message, status_code, content):
+ self.message = to_native_utf8_string(message)
+ self.status_code = status_code
+ self.content = content
def __str__(self):
return self.message
diff --git a/plotly/figure_factory/_2d_density.py b/plotly/figure_factory/_2d_density.py
new file mode 100644
index 00000000000..3be0bb06af1
--- /dev/null
+++ b/plotly/figure_factory/_2d_density.py
@@ -0,0 +1,166 @@
+from __future__ import absolute_import
+
+from numbers import Number
+
+from plotly import exceptions
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+
+def make_linear_colorscale(colors):
+ """
+ Makes a list of colors into a colorscale-acceptable form
+
+ For documentation regarding to the form of the output, see
+ https://plot.ly/python/reference/#mesh3d-colorscale
+ """
+ scale = 1. / (len(colors) - 1)
+ return [[i * scale, color] for i, color in enumerate(colors)]
+
+
+def create_2d_density(x, y, colorscale='Earth', ncontours=20,
+ hist_color=(0, 0, 0.5), point_color=(0, 0, 0.5),
+ point_size=2, title='2D Density Plot',
+ height=600, width=600):
+ """
+ Returns figure for a 2D density plot
+
+ :param (list|array) x: x-axis data for plot generation
+ :param (list|array) y: y-axis data for plot generation
+ :param (str|tuple|list) colorscale: either a plotly scale name, an rgb
+ or hex color, a color tuple or a list or tuple of colors. An rgb
+ color is of the form 'rgb(x, y, z)' where x, y, z belong to the
+ interval [0, 255] and a color tuple is a tuple of the form
+ (a, b, c) where a, b and c belong to [0, 1]. If colormap is a
+ list, it must contain the valid color types aforementioned as its
+ members.
+ :param (int) ncontours: the number of 2D contours to draw on the plot
+ :param (str) hist_color: the color of the plotted histograms
+ :param (str) point_color: the color of the scatter points
+ :param (str) point_size: the color of the scatter points
+ :param (str) title: set the title for the plot
+ :param (float) height: the height of the chart
+ :param (float) width: the width of the chart
+
+ Example 1: Simple 2D Density Plot
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory create_2d_density
+
+ import numpy as np
+
+ # Make data points
+ t = np.linspace(-1,1.2,2000)
+ x = (t**3)+(0.3*np.random.randn(2000))
+ y = (t**6)+(0.3*np.random.randn(2000))
+
+ # Create a figure
+ fig = create_2D_density(x, y)
+
+ # Plot the data
+ py.iplot(fig, filename='simple-2d-density')
+ ```
+
+ Example 2: Using Parameters
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory create_2d_density
+
+ import numpy as np
+
+ # Make data points
+ t = np.linspace(-1,1.2,2000)
+ x = (t**3)+(0.3*np.random.randn(2000))
+ y = (t**6)+(0.3*np.random.randn(2000))
+
+ # Create custom colorscale
+ colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
+ (1, 1, 0.2), (0.98,0.98,0.98)]
+
+ # Create a figure
+ fig = create_2D_density(
+ x, y, colorscale=colorscale,
+ hist_color='rgb(255, 237, 222)', point_size=3)
+
+ # Plot the data
+ py.iplot(fig, filename='use-parameters')
+ ```
+ """
+
+ # validate x and y are filled with numbers only
+ for array in [x, y]:
+ if not all(isinstance(element, Number) for element in array):
+ raise exceptions.PlotlyError(
+ "All elements of your 'x' and 'y' lists must be numbers."
+ )
+
+ # validate x and y are the same length
+ if len(x) != len(y):
+ raise exceptions.PlotlyError(
+ "Both lists 'x' and 'y' must be the same length."
+ )
+
+ colorscale = utils.validate_colors(colorscale, 'rgb')
+ colorscale = make_linear_colorscale(colorscale)
+
+ # validate hist_color and point_color
+ hist_color = utils.validate_colors(hist_color, 'rgb')
+ point_color = utils.validate_colors(point_color, 'rgb')
+
+ trace1 = graph_objs.Scatter(
+ x=x, y=y, mode='markers', name='points',
+ marker=dict(
+ color=point_color[0],
+ size=point_size,
+ opacity=0.4
+ )
+ )
+ trace2 = graph_objs.Histogram2dcontour(
+ x=x, y=y, name='density', ncontours=ncontours,
+ colorscale=colorscale, reversescale=True, showscale=False
+ )
+ trace3 = graph_objs.Histogram(
+ x=x, name='x density',
+ marker=dict(color=hist_color[0]), yaxis='y2'
+ )
+ trace4 = graph_objs.Histogram(
+ y=y, name='y density',
+ marker=dict(color=hist_color[0]), xaxis='x2'
+ )
+ data = [trace1, trace2, trace3, trace4]
+
+ layout = graph_objs.Layout(
+ showlegend=False,
+ autosize=False,
+ title=title,
+ height=height,
+ width=width,
+ xaxis=dict(
+ domain=[0, 0.85],
+ showgrid=False,
+ zeroline=False
+ ),
+ yaxis=dict(
+ domain=[0, 0.85],
+ showgrid=False,
+ zeroline=False
+ ),
+ margin=dict(
+ t=50
+ ),
+ hovermode='closest',
+ bargap=0,
+ xaxis2=dict(
+ domain=[0.85, 1],
+ showgrid=False,
+ zeroline=False
+ ),
+ yaxis2=dict(
+ domain=[0.85, 1],
+ showgrid=False,
+ zeroline=False
+ )
+ )
+
+ fig = graph_objs.Figure(data=data, layout=layout)
+ return fig
diff --git a/plotly/figure_factory/__init__.py b/plotly/figure_factory/__init__.py
new file mode 100644
index 00000000000..153ac6d657b
--- /dev/null
+++ b/plotly/figure_factory/__init__.py
@@ -0,0 +1,18 @@
+from __future__ import absolute_import
+
+# Require that numpy exists for figure_factory
+import numpy
+
+from plotly.figure_factory._2d_density import create_2d_density
+from plotly.figure_factory._annotated_heatmap import create_annotated_heatmap
+from plotly.figure_factory._candlestick import create_candlestick
+from plotly.figure_factory._dendrogram import create_dendrogram
+from plotly.figure_factory._distplot import create_distplot
+from plotly.figure_factory._gantt import create_gantt
+from plotly.figure_factory._ohlc import create_ohlc
+from plotly.figure_factory._quiver import create_quiver
+from plotly.figure_factory._scatterplot import create_scatterplotmatrix
+from plotly.figure_factory._streamline import create_streamline
+from plotly.figure_factory._table import create_table
+from plotly.figure_factory._trisurf import create_trisurf
+from plotly.figure_factory._violin import create_violin
diff --git a/plotly/figure_factory/_annotated_heatmap.py b/plotly/figure_factory/_annotated_heatmap.py
new file mode 100644
index 00000000000..36fc1ff7c16
--- /dev/null
+++ b/plotly/figure_factory/_annotated_heatmap.py
@@ -0,0 +1,239 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module('numpy')
+
+
+def validate_annotated_heatmap(z, x, y, annotation_text):
+ """
+ Annotated-heatmap-specific validations
+
+ Check that if a text matrix is supplied, it has the same
+ dimensions as the z matrix.
+
+ See FigureFactory.create_annotated_heatmap() for params
+
+ :raises: (PlotlyError) If z and text matrices do not have the same
+ dimensions.
+ """
+ if annotation_text is not None and isinstance(annotation_text, list):
+ utils.validate_equal_length(z, annotation_text)
+ for lst in range(len(z)):
+ if len(z[lst]) != len(annotation_text[lst]):
+ raise exceptions.PlotlyError("z and text should have the "
+ "same dimensions")
+
+ if x:
+ if len(x) != len(z[0]):
+ raise exceptions.PlotlyError("oops, the x list that you "
+ "provided does not match the "
+ "width of your z matrix ")
+
+ if y:
+ if len(y) != len(z):
+ raise exceptions.PlotlyError("oops, the y list that you "
+ "provided does not match the "
+ "length of your z matrix ")
+
+
+def create_annotated_heatmap(z, x=None, y=None, annotation_text=None,
+ colorscale='RdBu', font_colors=None,
+ showscale=False, reversescale=False,
+ **kwargs):
+ """
+ BETA function that creates annotated heatmaps
+
+ This function adds annotations to each cell of the heatmap.
+
+ :param (list[list]|ndarray) z: z matrix to create heatmap.
+ :param (list) x: x axis labels.
+ :param (list) y: y axis labels.
+ :param (list[list]|ndarray) annotation_text: Text strings for
+ annotations. Should have the same dimensions as the z matrix. If no
+ text is added, the values of the z matrix are annotated. Default =
+ z matrix values.
+ :param (list|str) colorscale: heatmap colorscale.
+ :param (list) font_colors: List of two color strings: [min_text_color,
+ max_text_color] where min_text_color is applied to annotations for
+ heatmap values < (max_value - min_value)/2. If font_colors is not
+ defined, the colors are defined logically as black or white
+ depending on the heatmap's colorscale.
+ :param (bool) showscale: Display colorscale. Default = False
+ :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
+ These kwargs describe other attributes about the annotated Heatmap
+ trace such as the colorscale. For more information on valid kwargs
+ call help(plotly.graph_objs.Heatmap)
+
+ Example 1: Simple annotated heatmap with default configuration
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory create_annotated_heatmap
+
+ z = [[0.300000, 0.00000, 0.65, 0.300000],
+ [1, 0.100005, 0.45, 0.4300],
+ [0.300000, 0.00000, 0.65, 0.300000],
+ [1, 0.100005, 0.45, 0.00000]]
+
+ figure = create_annotated_heatmap(z)
+ py.iplot(figure)
+ ```
+ """
+
+ # Avoiding mutables in the call signature
+ font_colors = font_colors if font_colors is not None else []
+ validate_annotated_heatmap(z, x, y, annotation_text)
+ annotations = _AnnotatedHeatmap(z, x, y, annotation_text,
+ colorscale, font_colors, reversescale,
+ **kwargs).make_annotations()
+
+ if x or y:
+ trace = dict(type='heatmap', z=z, x=x, y=y, colorscale=colorscale,
+ showscale=showscale, **kwargs)
+ layout = dict(annotations=annotations,
+ xaxis=dict(ticks='', dtick=1, side='top',
+ gridcolor='rgb(0, 0, 0)'),
+ yaxis=dict(ticks='', dtick=1, ticksuffix=' '))
+ else:
+ trace = dict(type='heatmap', z=z, colorscale=colorscale,
+ showscale=showscale, **kwargs)
+ layout = dict(annotations=annotations,
+ xaxis=dict(ticks='', side='top',
+ gridcolor='rgb(0, 0, 0)',
+ showticklabels=False),
+ yaxis=dict(ticks='', ticksuffix=' ',
+ showticklabels=False))
+
+ data = [trace]
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _AnnotatedHeatmap(object):
+ """
+ Refer to TraceFactory.create_annotated_heatmap() for docstring
+ """
+ def __init__(self, z, x, y, annotation_text, colorscale,
+ font_colors, reversescale, **kwargs):
+
+ self.z = z
+ if x:
+ self.x = x
+ else:
+ self.x = range(len(z[0]))
+ if y:
+ self.y = y
+ else:
+ self.y = range(len(z))
+ if annotation_text is not None:
+ self.annotation_text = annotation_text
+ else:
+ self.annotation_text = self.z
+ self.colorscale = colorscale
+ self.reversescale = reversescale
+ self.font_colors = font_colors
+
+ def get_text_color(self):
+ """
+ Get font color for annotations.
+
+ The annotated heatmap can feature two text colors: min_text_color and
+ max_text_color. The min_text_color is applied to annotations for
+ heatmap values < (max_value - min_value)/2. The user can define these
+ two colors. Otherwise the colors are defined logically as black or
+ white depending on the heatmap's colorscale.
+
+ :rtype (string, string) min_text_color, max_text_color: text
+ color for annotations for heatmap values <
+ (max_value - min_value)/2 and text color for annotations for
+ heatmap values >= (max_value - min_value)/2
+ """
+ # Plotly colorscales ranging from a lighter shade to a darker shade
+ colorscales = ['Greys', 'Greens', 'Blues',
+ 'YIGnBu', 'YIOrRd', 'RdBu',
+ 'Picnic', 'Jet', 'Hot', 'Blackbody',
+ 'Earth', 'Electric', 'Viridis']
+ # Plotly colorscales ranging from a darker shade to a lighter shade
+ colorscales_reverse = ['Reds']
+ if self.font_colors:
+ min_text_color = self.font_colors[0]
+ max_text_color = self.font_colors[-1]
+ elif self.colorscale in colorscales and self.reversescale:
+ min_text_color = '#000000'
+ max_text_color = '#FFFFFF'
+ elif self.colorscale in colorscales:
+ min_text_color = '#FFFFFF'
+ max_text_color = '#000000'
+ elif self.colorscale in colorscales_reverse and self.reversescale:
+ min_text_color = '#FFFFFF'
+ max_text_color = '#000000'
+ elif self.colorscale in colorscales_reverse:
+ min_text_color = '#000000'
+ max_text_color = '#FFFFFF'
+ elif isinstance(self.colorscale, list):
+ if 'rgb' in self.colorscale[0][1]:
+ min_col = map(int,
+ self.colorscale[0][1].strip('rgb()').split(','))
+ max_col = map(int,
+ self.colorscale[-1][1].strip('rgb()').split(','))
+ elif '#' in self.colorscale[0][1]:
+ min_col = utils.hex_to_rgb(self.colorscale[0][1])
+ max_col = utils.hex_to_rgb(self.colorscale[-1][1])
+ else:
+ min_col = [255, 255, 255]
+ max_col = [255, 255, 255]
+
+ if (min_col[0]*0.299 + min_col[1]*0.587 + min_col[2]*0.114) > 186:
+ min_text_color = '#000000'
+ else:
+ min_text_color = '#FFFFFF'
+ if (max_col[0]*0.299 + max_col[1]*0.587 + max_col[2]*0.114) > 186:
+ max_text_color = '#000000'
+ else:
+ max_text_color = '#FFFFFF'
+ else:
+ min_text_color = '#000000'
+ max_text_color = '#000000'
+ return min_text_color, max_text_color
+
+ def get_z_mid(self):
+ """
+ Get the mid value of z matrix
+
+ :rtype (float) z_avg: average val from z matrix
+ """
+ if np and isinstance(self.z, np.ndarray):
+ z_min = np.amin(self.z)
+ z_max = np.amax(self.z)
+ else:
+ z_min = min(min(self.z))
+ z_max = max(max(self.z))
+ z_mid = (z_max+z_min) / 2
+ return z_mid
+
+ def make_annotations(self):
+ """
+ Get annotations for each cell of the heatmap with graph_objs.Annotation
+
+ :rtype (list[dict]) annotations: list of annotations for each cell of
+ the heatmap
+ """
+ min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
+ z_mid = _AnnotatedHeatmap.get_z_mid(self)
+ annotations = []
+ for n, row in enumerate(self.z):
+ for m, val in enumerate(row):
+ font_color = min_text_color if val < z_mid else max_text_color
+ annotations.append(
+ graph_objs.Annotation(
+ text=str(self.annotation_text[n][m]),
+ x=self.x[m],
+ y=self.y[n],
+ xref='x1',
+ yref='y1',
+ font=dict(color=font_color),
+ showarrow=False))
+ return annotations
diff --git a/plotly/figure_factory/_candlestick.py b/plotly/figure_factory/_candlestick.py
new file mode 100644
index 00000000000..925b4c1a62b
--- /dev/null
+++ b/plotly/figure_factory/_candlestick.py
@@ -0,0 +1,294 @@
+from __future__ import absolute_import
+
+from plotly.figure_factory import utils
+from plotly.figure_factory._ohlc import (_DEFAULT_INCREASING_COLOR,
+ _DEFAULT_DECREASING_COLOR,
+ validate_ohlc)
+from plotly.graph_objs import graph_objs
+
+
+def make_increasing_candle(open, high, low, close, dates, **kwargs):
+ """
+ Makes boxplot trace for increasing candlesticks
+
+ _make_increasing_candle() and _make_decreasing_candle separate the
+ increasing traces from the decreasing traces so kwargs (such as
+ color) can be passed separately to increasing or decreasing traces
+ when direction is set to 'increasing' or 'decreasing' in
+ FigureFactory.create_candlestick()
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (list) candle_incr_data: list of the box trace for
+ increasing candlesticks.
+ """
+ increase_x, increase_y = _Candlestick(
+ open, high, low, close, dates, **kwargs).get_candle_increase()
+
+ if 'line' in kwargs:
+ kwargs.setdefault('fillcolor', kwargs['line']['color'])
+ else:
+ kwargs.setdefault('fillcolor', _DEFAULT_INCREASING_COLOR)
+ if 'name' in kwargs:
+ kwargs.setdefault('showlegend', True)
+ else:
+ kwargs.setdefault('showlegend', False)
+ kwargs.setdefault('name', 'Increasing')
+ kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR))
+
+ candle_incr_data = dict(type='box',
+ x=increase_x,
+ y=increase_y,
+ whiskerwidth=0,
+ boxpoints=False,
+ **kwargs)
+
+ return [candle_incr_data]
+
+
+def make_decreasing_candle(open, high, low, close, dates, **kwargs):
+ """
+ Makes boxplot trace for decreasing candlesticks
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to decreasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (list) candle_decr_data: list of the box trace for
+ decreasing candlesticks.
+ """
+
+ decrease_x, decrease_y = _Candlestick(
+ open, high, low, close, dates, **kwargs).get_candle_decrease()
+
+ if 'line' in kwargs:
+ kwargs.setdefault('fillcolor', kwargs['line']['color'])
+ else:
+ kwargs.setdefault('fillcolor', _DEFAULT_DECREASING_COLOR)
+ kwargs.setdefault('showlegend', False)
+ kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR))
+ kwargs.setdefault('name', 'Decreasing')
+
+ candle_decr_data = dict(type='box',
+ x=decrease_x,
+ y=decrease_y,
+ whiskerwidth=0,
+ boxpoints=False,
+ **kwargs)
+
+ return [candle_decr_data]
+
+
+def create_candlestick(open, high, low, close, dates=None, direction='both',
+ **kwargs):
+ """
+ BETA function that creates a candlestick chart
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param (string) direction: direction can be 'increasing', 'decreasing',
+ or 'both'. When the direction is 'increasing', the returned figure
+ consists of all candlesticks where the close value is greater than
+ the corresponding open value, and when the direction is
+ 'decreasing', the returned figure consists of all candlesticks
+ where the close value is less than or equal to the corresponding
+ open value. When the direction is 'both', both increasing and
+ decreasing candlesticks are returned. Default: 'both'
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
+ These kwargs describe other attributes about the ohlc Scatter trace
+ such as the color or the legend name. For more information on valid
+ kwargs call help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of candlestick chart figure.
+
+ Example 1: Simple candlestick chart from a Pandas DataFrame
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_candlestick
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1), datetime(2009, 4, 1))
+ fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index)
+ py.plot(fig, filename='finance/aapl-candlestick', validate=False)
+ ```
+
+ Example 2: Add text and annotations to the candlestick chart
+ ```
+ fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index)
+ # Update the fig - all options here: https://plot.ly/python/reference/#Layout
+ fig['layout'].update({
+ 'title': 'The Great Recession',
+ 'yaxis': {'title': 'AAPL Stock'},
+ 'shapes': [{
+ 'x0': '2007-12-01', 'x1': '2007-12-01',
+ 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
+ 'line': {'color': 'rgb(30,30,30)', 'width': 1}
+ }],
+ 'annotations': [{
+ 'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper',
+ 'showarrow': False, 'xanchor': 'left',
+ 'text': 'Official start of the recession'
+ }]
+ })
+ py.plot(fig, filename='finance/aapl-recession-candlestick', validate=False)
+ ```
+
+ Example 3: Customize the candlestick colors
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_candlestick
+ from plotly.graph_objs import Line, Marker
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1))
+
+ # Make increasing candlesticks and customize their color and name
+ fig_increasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index,
+ direction='increasing', name='AAPL',
+ marker=Marker(color='rgb(150, 200, 250)'),
+ line=Line(color='rgb(150, 200, 250)'))
+
+ # Make decreasing candlesticks and customize their color and name
+ fig_decreasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index,
+ direction='decreasing',
+ marker=Marker(color='rgb(128, 128, 128)'),
+ line=Line(color='rgb(128, 128, 128)'))
+
+ # Initialize the figure
+ fig = fig_increasing
+
+ # Add decreasing data with .extend()
+ fig['data'].extend(fig_decreasing['data'])
+
+ py.iplot(fig, filename='finance/aapl-candlestick-custom', validate=False)
+ ```
+
+ Example 4: Candlestick chart with datetime objects
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_candlestick
+
+ from datetime import datetime
+
+ # Add data
+ open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
+ high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
+ low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
+ close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
+ dates = [datetime(year=2013, month=10, day=10),
+ datetime(year=2013, month=11, day=10),
+ datetime(year=2013, month=12, day=10),
+ datetime(year=2014, month=1, day=10),
+ datetime(year=2014, month=2, day=10)]
+
+ # Create ohlc
+ fig = create_candlestick(open_data, high_data,
+ low_data, close_data, dates=dates)
+
+ py.iplot(fig, filename='finance/simple-candlestick', validate=False)
+ ```
+ """
+ if dates is not None:
+ utils.validate_equal_length(open, high, low, close, dates)
+ else:
+ utils.validate_equal_length(open, high, low, close)
+ validate_ohlc(open, high, low, close, direction, **kwargs)
+
+ if direction is 'increasing':
+ candle_incr_data = make_increasing_candle(open, high, low, close,
+ dates, **kwargs)
+ data = candle_incr_data
+ elif direction is 'decreasing':
+ candle_decr_data = make_decreasing_candle(open, high, low, close,
+ dates, **kwargs)
+ data = candle_decr_data
+ else:
+ candle_incr_data = make_increasing_candle(open, high, low, close,
+ dates, **kwargs)
+ candle_decr_data = make_decreasing_candle(open, high, low, close,
+ dates, **kwargs)
+ data = candle_incr_data + candle_decr_data
+
+ layout = graph_objs.Layout()
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Candlestick(object):
+ """
+ Refer to FigureFactory.create_candlestick() for docstring.
+ """
+ def __init__(self, open, high, low, close, dates, **kwargs):
+ self.open = open
+ self.high = high
+ self.low = low
+ self.close = close
+ if dates is not None:
+ self.x = dates
+ else:
+ self.x = [x for x in range(len(self.open))]
+ self.get_candle_increase()
+
+ def get_candle_increase(self):
+ """
+ Separate increasing data from decreasing data.
+
+ The data is increasing when close value > open value
+ and decreasing when the close value <= open value.
+ """
+ increase_y = []
+ increase_x = []
+ for index in range(len(self.open)):
+ if self.close[index] > self.open[index]:
+ increase_y.append(self.low[index])
+ increase_y.append(self.open[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.high[index])
+ increase_x.append(self.x[index])
+
+ increase_x = [[x, x, x, x, x, x] for x in increase_x]
+ increase_x = utils.flatten(increase_x)
+
+ return increase_x, increase_y
+
+ def get_candle_decrease(self):
+ """
+ Separate increasing data from decreasing data.
+
+ The data is increasing when close value > open value
+ and decreasing when the close value <= open value.
+ """
+ decrease_y = []
+ decrease_x = []
+ for index in range(len(self.open)):
+ if self.close[index] <= self.open[index]:
+ decrease_y.append(self.low[index])
+ decrease_y.append(self.open[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.high[index])
+ decrease_x.append(self.x[index])
+
+ decrease_x = [[x, x, x, x, x, x] for x in decrease_x]
+ decrease_x = utils.flatten(decrease_x)
+
+ return decrease_x, decrease_y
diff --git a/plotly/figure_factory/_dendrogram.py b/plotly/figure_factory/_dendrogram.py
new file mode 100644
index 00000000000..26e6692290d
--- /dev/null
+++ b/plotly/figure_factory/_dendrogram.py
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import
+
+from collections import OrderedDict
+
+from plotly import exceptions, optional_imports
+from plotly.graph_objs import graph_objs
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module('numpy')
+scp = optional_imports.get_module('scipy')
+sch = optional_imports.get_module('scipy.cluster.hierarchy')
+scs = optional_imports.get_module('scipy.spatial')
+
+
+def create_dendrogram(X, orientation="bottom", labels=None,
+ colorscale=None, distfun=None,
+ linkagefun=lambda x: sch.linkage(x, 'complete')):
+ """
+ BETA function that returns a dendrogram Plotly figure object.
+
+ :param (ndarray) X: Matrix of observations as array of arrays
+ :param (str) orientation: 'top', 'right', 'bottom', or 'left'
+ :param (list) labels: List of axis category labels(observation labels)
+ :param (list) colorscale: Optional colorscale for dendrogram tree
+ :param (function) distfun: Function to compute the pairwise distance from
+ the observations
+ :param (function) linkagefun: Function to compute the linkage matrix from
+ the pairwise distances
+
+ clusters
+
+ Example 1: Simple bottom oriented dendrogram
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_dendrogram
+
+ import numpy as np
+
+ X = np.random.rand(10,10)
+ dendro = create_dendrogram(X)
+ plot_url = py.plot(dendro, filename='simple-dendrogram')
+
+ ```
+
+ Example 2: Dendrogram to put on the left of the heatmap
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_dendrogram
+
+ import numpy as np
+
+ X = np.random.rand(5,5)
+ names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
+ dendro = create_dendrogram(X, orientation='right', labels=names)
+ dendro['layout'].update({'width':700, 'height':500})
+
+ py.iplot(dendro, filename='vertical-dendrogram')
+ ```
+
+ Example 3: Dendrogram with Pandas
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_dendrogram
+
+ import numpy as np
+ import pandas as pd
+
+ Index= ['A','B','C','D','E','F','G','H','I','J']
+ df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
+ fig = create_dendrogram(df, labels=Index)
+ url = py.plot(fig, filename='pandas-dendrogram')
+ ```
+ """
+ if not scp or not scs or not sch:
+ raise ImportError("FigureFactory.create_dendrogram requires scipy, \
+ scipy.spatial and scipy.hierarchy")
+
+ s = X.shape
+ if len(s) != 2:
+ exceptions.PlotlyError("X should be 2-dimensional array.")
+
+ if distfun is None:
+ distfun = scs.distance.pdist
+
+ dendrogram = _Dendrogram(X, orientation, labels, colorscale,
+ distfun=distfun, linkagefun=linkagefun)
+
+ return {'layout': dendrogram.layout,
+ 'data': dendrogram.data}
+
+
+class _Dendrogram(object):
+ """Refer to FigureFactory.create_dendrogram() for docstring."""
+
+ def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
+ width="100%", height="100%", xaxis='xaxis', yaxis='yaxis',
+ distfun=None,
+ linkagefun=lambda x: sch.linkage(x, 'complete')):
+ self.orientation = orientation
+ self.labels = labels
+ self.xaxis = xaxis
+ self.yaxis = yaxis
+ self.data = []
+ self.leaves = []
+ self.sign = {self.xaxis: 1, self.yaxis: 1}
+ self.layout = {self.xaxis: {}, self.yaxis: {}}
+
+ if self.orientation in ['left', 'bottom']:
+ self.sign[self.xaxis] = 1
+ else:
+ self.sign[self.xaxis] = -1
+
+ if self.orientation in ['right', 'bottom']:
+ self.sign[self.yaxis] = 1
+ else:
+ self.sign[self.yaxis] = -1
+
+ if distfun is None:
+ distfun = scs.distance.pdist
+
+ (dd_traces, xvals, yvals,
+ ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale,
+ distfun,
+ linkagefun)
+
+ self.labels = ordered_labels
+ self.leaves = leaves
+ yvals_flat = yvals.flatten()
+ xvals_flat = xvals.flatten()
+
+ self.zero_vals = []
+
+ for i in range(len(yvals_flat)):
+ if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
+ self.zero_vals.append(xvals_flat[i])
+
+ self.zero_vals.sort()
+
+ self.layout = self.set_figure_layout(width, height)
+ self.data = graph_objs.Data(dd_traces)
+
+ def get_color_dict(self, colorscale):
+ """
+ Returns colorscale used for dendrogram tree clusters.
+
+ :param (list) colorscale: Colors to use for the plot in rgb format.
+ :rtype (dict): A dict of default colors mapped to the user colorscale.
+
+ """
+
+ # These are the color codes returned for dendrograms
+ # We're replacing them with nicer colors
+ d = {'r': 'red',
+ 'g': 'green',
+ 'b': 'blue',
+ 'c': 'cyan',
+ 'm': 'magenta',
+ 'y': 'yellow',
+ 'k': 'black',
+ 'w': 'white'}
+ default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
+
+ if colorscale is None:
+ colorscale = [
+ 'rgb(0,116,217)', # blue
+ 'rgb(35,205,205)', # cyan
+ 'rgb(61,153,112)', # green
+ 'rgb(40,35,35)', # black
+ 'rgb(133,20,75)', # magenta
+ 'rgb(255,65,54)', # red
+ 'rgb(255,255,255)', # white
+ 'rgb(255,220,0)'] # yellow
+
+ for i in range(len(default_colors.keys())):
+ k = list(default_colors.keys())[i] # PY3 won't index keys
+ if i < len(colorscale):
+ default_colors[k] = colorscale[i]
+
+ return default_colors
+
+ def set_axis_layout(self, axis_key):
+ """
+ Sets and returns default axis object for dendrogram figure.
+
+ :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
+ :rtype (dict): An axis_key dictionary with set parameters.
+
+ """
+ axis_defaults = {
+ 'type': 'linear',
+ 'ticks': 'outside',
+ 'mirror': 'allticks',
+ 'rangemode': 'tozero',
+ 'showticklabels': True,
+ 'zeroline': False,
+ 'showgrid': False,
+ 'showline': True,
+ }
+
+ if len(self.labels) != 0:
+ axis_key_labels = self.xaxis
+ if self.orientation in ['left', 'right']:
+ axis_key_labels = self.yaxis
+ if axis_key_labels not in self.layout:
+ self.layout[axis_key_labels] = {}
+ self.layout[axis_key_labels]['tickvals'] = \
+ [zv*self.sign[axis_key] for zv in self.zero_vals]
+ self.layout[axis_key_labels]['ticktext'] = self.labels
+ self.layout[axis_key_labels]['tickmode'] = 'array'
+
+ self.layout[axis_key].update(axis_defaults)
+
+ return self.layout[axis_key]
+
+ def set_figure_layout(self, width, height):
+ """
+ Sets and returns default layout object for dendrogram figure.
+
+ """
+ self.layout.update({
+ 'showlegend': False,
+ 'autosize': False,
+ 'hovermode': 'closest',
+ 'width': width,
+ 'height': height
+ })
+
+ self.set_axis_layout(self.xaxis)
+ self.set_axis_layout(self.yaxis)
+
+ return self.layout
+
+ def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun):
+ """
+ Calculates all the elements needed for plotting a dendrogram.
+
+ :param (ndarray) X: Matrix of observations as array of arrays
+ :param (list) colorscale: Color scale for dendrogram tree clusters
+ :param (function) distfun: Function to compute the pairwise distance
+ from the observations
+ :param (function) linkagefun: Function to compute the linkage matrix
+ from the pairwise distances
+ :rtype (tuple): Contains all the traces in the following order:
+ (a) trace_list: List of Plotly trace objects for dendrogram tree
+ (b) icoord: All X points of the dendrogram tree as array of arrays
+ with length 4
+ (c) dcoord: All Y points of the dendrogram tree as array of arrays
+ with length 4
+ (d) ordered_labels: leaf labels in the order they are going to
+ appear on the plot
+ (e) P['leaves']: left-to-right traversal of the leaves
+
+ """
+ d = distfun(X)
+ Z = linkagefun(d)
+ P = sch.dendrogram(Z, orientation=self.orientation,
+ labels=self.labels, no_plot=True)
+
+ icoord = scp.array(P['icoord'])
+ dcoord = scp.array(P['dcoord'])
+ ordered_labels = scp.array(P['ivl'])
+ color_list = scp.array(P['color_list'])
+ colors = self.get_color_dict(colorscale)
+
+ trace_list = []
+
+ for i in range(len(icoord)):
+ # xs and ys are arrays of 4 points that make up the '∩' shapes
+ # of the dendrogram tree
+ if self.orientation in ['top', 'bottom']:
+ xs = icoord[i]
+ else:
+ xs = dcoord[i]
+
+ if self.orientation in ['top', 'bottom']:
+ ys = dcoord[i]
+ else:
+ ys = icoord[i]
+ color_key = color_list[i]
+ trace = graph_objs.Scatter(
+ x=np.multiply(self.sign[self.xaxis], xs),
+ y=np.multiply(self.sign[self.yaxis], ys),
+ mode='lines',
+ marker=graph_objs.Marker(color=colors[color_key])
+ )
+
+ try:
+ x_index = int(self.xaxis[-1])
+ except ValueError:
+ x_index = ''
+
+ try:
+ y_index = int(self.yaxis[-1])
+ except ValueError:
+ y_index = ''
+
+ trace['xaxis'] = 'x' + x_index
+ trace['yaxis'] = 'y' + y_index
+
+ trace_list.append(trace)
+
+ return trace_list, icoord, dcoord, ordered_labels, P['leaves']
diff --git a/plotly/figure_factory/_distplot.py b/plotly/figure_factory/_distplot.py
new file mode 100644
index 00000000000..d2b77fb0873
--- /dev/null
+++ b/plotly/figure_factory/_distplot.py
@@ -0,0 +1,390 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module('numpy')
+pd = optional_imports.get_module('pandas')
+scipy = optional_imports.get_module('scipy')
+scipy_stats = optional_imports.get_module('scipy.stats')
+
+
+DEFAULT_HISTNORM = 'probability density'
+ALTERNATIVE_HISTNORM = 'probability'
+
+
+def validate_distplot(hist_data, curve_type):
+ """
+ Distplot-specific validations
+
+ :raises: (PlotlyError) If hist_data is not a list of lists
+ :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
+ 'normal').
+ """
+ hist_data_types = (list,)
+ if np:
+ hist_data_types += (np.ndarray,)
+ if pd:
+ hist_data_types += (pd.core.series.Series,)
+
+ if not isinstance(hist_data[0], hist_data_types):
+ raise exceptions.PlotlyError("Oops, this function was written "
+ "to handle multiple datasets, if "
+ "you want to plot just one, make "
+ "sure your hist_data variable is "
+ "still a list of lists, i.e. x = "
+ "[1, 2, 3] -> x = [[1, 2, 3]]")
+
+ curve_opts = ('kde', 'normal')
+ if curve_type not in curve_opts:
+ raise exceptions.PlotlyError("curve_type must be defined as "
+ "'kde' or 'normal'")
+
+ if not scipy:
+ raise ImportError("FigureFactory.create_distplot requires scipy")
+
+
+def create_distplot(hist_data, group_labels, bin_size=1., curve_type='kde',
+ colors=None, rug_text=None, histnorm=DEFAULT_HISTNORM,
+ show_hist=True, show_curve=True, show_rug=True):
+ """
+ BETA function that creates a distplot similar to seaborn.distplot
+
+ The distplot can be composed of all or any combination of the following
+ 3 components: (1) histogram, (2) curve: (a) kernel density estimation
+ or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
+ (from multiple datasets) can be created in the same plot.
+
+ :param (list[list]) hist_data: Use list of lists to plot multiple data
+ sets on the same plot.
+ :param (list[str]) group_labels: Names for each data set.
+ :param (list[float]|float) bin_size: Size of histogram bins.
+ Default = 1.
+ :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
+ :param (str) histnorm: 'probability density' or 'probability'
+ Default = 'probability density'
+ :param (bool) show_hist: Add histogram to distplot? Default = True
+ :param (bool) show_curve: Add curve to distplot? Default = True
+ :param (bool) show_rug: Add rug to distplot? Default = True
+ :param (list[str]) colors: Colors for traces.
+ :param (list[list]) rug_text: Hovertext values for rug_plot,
+ :return (dict): Representation of a distplot figure.
+
+ Example 1: Simple distplot of 1 data set
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_distplot
+
+ hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
+ 3.5, 4.1, 4.4, 4.5, 4.5,
+ 5.0, 5.0, 5.2, 5.5, 5.5,
+ 5.5, 5.5, 5.5, 6.1, 7.0]]
+
+ group_labels = ['distplot example']
+
+ fig = create_distplot(hist_data, group_labels)
+
+ url = py.plot(fig, filename='Simple distplot', validate=False)
+ ```
+
+ Example 2: Two data sets and added rug text
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_distplot
+
+ # Add histogram data
+ hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
+ -0.9, -0.07, 1.95, 0.9, -0.2,
+ -0.5, 0.3, 0.4, -0.37, 0.6]
+ hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
+ 1.0, 0.8, 1.7, 0.5, 0.8,
+ -0.3, 1.2, 0.56, 0.3, 2.2]
+
+ # Group data together
+ hist_data = [hist1_x, hist2_x]
+
+ group_labels = ['2012', '2013']
+
+ # Add text
+ rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
+ 'f1', 'g1', 'h1', 'i1', 'j1',
+ 'k1', 'l1', 'm1', 'n1', 'o1']
+
+ rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
+ 'f2', 'g2', 'h2', 'i2', 'j2',
+ 'k2', 'l2', 'm2', 'n2', 'o2']
+
+ # Group text together
+ rug_text_all = [rug_text_1, rug_text_2]
+
+ # Create distplot
+ fig = create_distplot(
+ hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)
+
+ # Add title
+ fig['layout'].update(title='Dist Plot')
+
+ # Plot!
+ url = py.plot(fig, filename='Distplot with rug text', validate=False)
+ ```
+
+ Example 3: Plot with normal curve and hide rug plot
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_distplot
+ import numpy as np
+
+ x1 = np.random.randn(190)
+ x2 = np.random.randn(200)+1
+ x3 = np.random.randn(200)-1
+ x4 = np.random.randn(210)+2
+
+ hist_data = [x1, x2, x3, x4]
+ group_labels = ['2012', '2013', '2014', '2015']
+
+ fig = create_distplot(
+ hist_data, group_labels, curve_type='normal',
+ show_rug=False, bin_size=.4)
+
+ url = py.plot(fig, filename='hist and normal curve', validate=False)
+
+ Example 4: Distplot with Pandas
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_distplot
+ import numpy as np
+ import pandas as pd
+
+ df = pd.DataFrame({'2012': np.random.randn(200),
+ '2013': np.random.randn(200)+1})
+ py.iplot(create_distplot([df[c] for c in df.columns], df.columns),
+ filename='examples/distplot with pandas',
+ validate=False)
+ ```
+ """
+ if colors is None:
+ colors = []
+ if rug_text is None:
+ rug_text = []
+
+ validate_distplot(hist_data, curve_type)
+ utils.validate_equal_length(hist_data, group_labels)
+
+ if isinstance(bin_size, (float, int)):
+ bin_size = [bin_size] * len(hist_data)
+
+ hist = _Distplot(
+ hist_data, histnorm, group_labels, bin_size,
+ curve_type, colors, rug_text,
+ show_hist, show_curve).make_hist()
+
+ if curve_type == 'normal':
+ curve = _Distplot(
+ hist_data, histnorm, group_labels, bin_size,
+ curve_type, colors, rug_text,
+ show_hist, show_curve).make_normal()
+ else:
+ curve = _Distplot(
+ hist_data, histnorm, group_labels, bin_size,
+ curve_type, colors, rug_text,
+ show_hist, show_curve).make_kde()
+
+ rug = _Distplot(
+ hist_data, histnorm, group_labels, bin_size,
+ curve_type, colors, rug_text,
+ show_hist, show_curve).make_rug()
+
+ data = []
+ if show_hist:
+ data.append(hist)
+ if show_curve:
+ data.append(curve)
+ if show_rug:
+ data.append(rug)
+ layout = graph_objs.Layout(
+ barmode='overlay',
+ hovermode='closest',
+ legend=dict(traceorder='reversed'),
+ xaxis1=dict(domain=[0.0, 1.0],
+ anchor='y2',
+ zeroline=False),
+ yaxis1=dict(domain=[0.35, 1],
+ anchor='free',
+ position=0.0),
+ yaxis2=dict(domain=[0, 0.25],
+ anchor='x1',
+ dtick=1,
+ showticklabels=False))
+ else:
+ layout = graph_objs.Layout(
+ barmode='overlay',
+ hovermode='closest',
+ legend=dict(traceorder='reversed'),
+ xaxis1=dict(domain=[0.0, 1.0],
+ anchor='y2',
+ zeroline=False),
+ yaxis1=dict(domain=[0., 1],
+ anchor='free',
+ position=0.0))
+
+ data = sum(data, [])
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Distplot(object):
+ """
+ Refer to TraceFactory.create_distplot() for docstring
+ """
+ def __init__(self, hist_data, histnorm, group_labels,
+ bin_size, curve_type, colors,
+ rug_text, show_hist, show_curve):
+ self.hist_data = hist_data
+ self.histnorm = histnorm
+ self.group_labels = group_labels
+ self.bin_size = bin_size
+ self.show_hist = show_hist
+ self.show_curve = show_curve
+ self.trace_number = len(hist_data)
+ if rug_text:
+ self.rug_text = rug_text
+ else:
+ self.rug_text = [None] * self.trace_number
+
+ self.start = []
+ self.end = []
+ if colors:
+ self.colors = colors
+ else:
+ self.colors = [
+ "rgb(31, 119, 180)", "rgb(255, 127, 14)",
+ "rgb(44, 160, 44)", "rgb(214, 39, 40)",
+ "rgb(148, 103, 189)", "rgb(140, 86, 75)",
+ "rgb(227, 119, 194)", "rgb(127, 127, 127)",
+ "rgb(188, 189, 34)", "rgb(23, 190, 207)"]
+ self.curve_x = [None] * self.trace_number
+ self.curve_y = [None] * self.trace_number
+
+ for trace in self.hist_data:
+ self.start.append(min(trace) * 1.)
+ self.end.append(max(trace) * 1.)
+
+ def make_hist(self):
+ """
+ Makes the histogram(s) for FigureFactory.create_distplot().
+
+ :rtype (list) hist: list of histogram representations
+ """
+ hist = [None] * self.trace_number
+
+ for index in range(self.trace_number):
+ hist[index] = dict(type='histogram',
+ x=self.hist_data[index],
+ xaxis='x1',
+ yaxis='y1',
+ histnorm=self.histnorm,
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ marker=dict(color=self.colors[index]),
+ autobinx=False,
+ xbins=dict(start=self.start[index],
+ end=self.end[index],
+ size=self.bin_size[index]),
+ opacity=.7)
+ return hist
+
+ def make_kde(self):
+ """
+ Makes the kernel density estimation(s) for create_distplot().
+
+ This is called when curve_type = 'kde' in create_distplot().
+
+ :rtype (list) curve: list of kde representations
+ """
+ curve = [None] * self.trace_number
+ for index in range(self.trace_number):
+ self.curve_x[index] = [self.start[index] +
+ x * (self.end[index] - self.start[index])
+ / 500 for x in range(500)]
+ self.curve_y[index] = (scipy_stats.gaussian_kde
+ (self.hist_data[index])
+ (self.curve_x[index]))
+
+ if self.histnorm == ALTERNATIVE_HISTNORM:
+ self.curve_y[index] *= self.bin_size[index]
+
+ for index in range(self.trace_number):
+ curve[index] = dict(type='scatter',
+ x=self.curve_x[index],
+ y=self.curve_y[index],
+ xaxis='x1',
+ yaxis='y1',
+ mode='lines',
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=False if self.show_hist else True,
+ marker=dict(color=self.colors[index]))
+ return curve
+
+ def make_normal(self):
+ """
+ Makes the normal curve(s) for create_distplot().
+
+ This is called when curve_type = 'normal' in create_distplot().
+
+ :rtype (list) curve: list of normal curve representations
+ """
+ curve = [None] * self.trace_number
+ mean = [None] * self.trace_number
+ sd = [None] * self.trace_number
+
+ for index in range(self.trace_number):
+ mean[index], sd[index] = (scipy_stats.norm.fit
+ (self.hist_data[index]))
+ self.curve_x[index] = [self.start[index] +
+ x * (self.end[index] - self.start[index])
+ / 500 for x in range(500)]
+ self.curve_y[index] = scipy_stats.norm.pdf(
+ self.curve_x[index], loc=mean[index], scale=sd[index])
+
+ if self.histnorm == ALTERNATIVE_HISTNORM:
+ self.curve_y[index] *= self.bin_size[index]
+
+ for index in range(self.trace_number):
+ curve[index] = dict(type='scatter',
+ x=self.curve_x[index],
+ y=self.curve_y[index],
+ xaxis='x1',
+ yaxis='y1',
+ mode='lines',
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=False if self.show_hist else True,
+ marker=dict(color=self.colors[index]))
+ return curve
+
+ def make_rug(self):
+ """
+ Makes the rug plot(s) for create_distplot().
+
+ :rtype (list) rug: list of rug plot representations
+ """
+ rug = [None] * self.trace_number
+ for index in range(self.trace_number):
+
+ rug[index] = dict(type='scatter',
+ x=self.hist_data[index],
+ y=([self.group_labels[index]] *
+ len(self.hist_data[index])),
+ xaxis='x1',
+ yaxis='y2',
+ mode='markers',
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=(False if self.show_hist or
+ self.show_curve else True),
+ text=self.rug_text[index],
+ marker=dict(color=self.colors[index],
+ symbol='line-ns-open'))
+ return rug
diff --git a/plotly/figure_factory/_gantt.py b/plotly/figure_factory/_gantt.py
new file mode 100644
index 00000000000..c78434f0381
--- /dev/null
+++ b/plotly/figure_factory/_gantt.py
@@ -0,0 +1,778 @@
+from __future__ import absolute_import
+
+from numbers import Number
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+
+pd = optional_imports.get_module('pandas')
+
+REQUIRED_GANTT_KEYS = ['Task', 'Start', 'Finish']
+
+
+def validate_gantt(df):
+ """
+ Validates the inputted dataframe or list
+ """
+ if pd and isinstance(df, pd.core.frame.DataFrame):
+ # validate that df has all the required keys
+ for key in REQUIRED_GANTT_KEYS:
+ if key not in df:
+ raise exceptions.PlotlyError(
+ "The columns in your dataframe must include the "
+ "following keys: {0}".format(
+ ', '.join(REQUIRED_GANTT_KEYS))
+ )
+
+ num_of_rows = len(df.index)
+ chart = []
+ for index in range(num_of_rows):
+ task_dict = {}
+ for key in df:
+ task_dict[key] = df.ix[index][key]
+ chart.append(task_dict)
+
+ return chart
+
+ # validate if df is a list
+ if not isinstance(df, list):
+ raise exceptions.PlotlyError("You must input either a dataframe "
+ "or a list of dictionaries.")
+
+ # validate if df is empty
+ if len(df) <= 0:
+ raise exceptions.PlotlyError("Your list is empty. It must contain "
+ "at least one dictionary.")
+ if not isinstance(df[0], dict):
+ raise exceptions.PlotlyError("Your list must only "
+ "include dictionaries.")
+ return df
+
+
+def gantt(chart, colors, title, bar_width, showgrid_x, showgrid_y, height,
+ width, tasks=None, task_names=None, data=None, group_tasks=False):
+ """
+ Refer to create_gantt() for docstring
+ """
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+
+ for index in range(len(chart)):
+ task = dict(x0=chart[index]['Start'],
+ x1=chart[index]['Finish'],
+ name=chart[index]['Task'])
+ if 'Description' in chart[index]:
+ task['description'] = chart[index]['Description']
+ tasks.append(task)
+
+ shape_template = {
+ 'type': 'rect',
+ 'xref': 'x',
+ 'yref': 'y',
+ 'opacity': 1,
+ 'line': {
+ 'width': 0,
+ }
+ }
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted first
+ # are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ color_index = 0
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ del tasks[index]['name']
+ tasks[index].update(shape_template)
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]['y0'] = groupID - bar_width
+ tasks[index]['y1'] = groupID + bar_width
+
+ # check if colors need to be looped
+ if color_index >= len(colors):
+ color_index = 0
+ tasks[index]['fillcolor'] = colors[color_index]
+ # Add a line for hover text and autorange
+ entry = dict(
+ x=[tasks[index]['x0'], tasks[index]['x1']],
+ y=[groupID, groupID],
+ name='',
+ marker={'color': 'white'}
+ )
+ if "description" in tasks[index]:
+ entry['text'] = tasks[index]['description']
+ del tasks[index]['description']
+ data.append(entry)
+ color_index += 1
+
+ layout = dict(
+ title=title,
+ showlegend=False,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode='closest',
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list([
+ dict(count=7,
+ label='1w',
+ step='day',
+ stepmode='backward'),
+ dict(count=1,
+ label='1m',
+ step='month',
+ stepmode='backward'),
+ dict(count=6,
+ label='6m',
+ step='month',
+ stepmode='backward'),
+ dict(count=1,
+ label='YTD',
+ step='year',
+ stepmode='todate'),
+ dict(count=1,
+ label='1y',
+ step='year',
+ stepmode='backward'),
+ dict(step='all')
+ ])
+ ),
+ type='date'
+ )
+ )
+ layout['shapes'] = tasks
+
+ fig = dict(data=data, layout=layout)
+ return fig
+
+
+def gantt_colorscale(chart, colors, title, index_col, show_colorbar, bar_width,
+ showgrid_x, showgrid_y, height, width, tasks=None,
+ task_names=None, data=None, group_tasks=False):
+ """
+ Refer to FigureFactory.create_gantt() for docstring
+ """
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+ showlegend = False
+
+ for index in range(len(chart)):
+ task = dict(x0=chart[index]['Start'],
+ x1=chart[index]['Finish'],
+ name=chart[index]['Task'])
+ if 'Description' in chart[index]:
+ task['description'] = chart[index]['Description']
+ tasks.append(task)
+
+ shape_template = {
+ 'type': 'rect',
+ 'xref': 'x',
+ 'yref': 'y',
+ 'opacity': 1,
+ 'line': {
+ 'width': 0,
+ }
+ }
+
+ # compute the color for task based on indexing column
+ if isinstance(chart[0][index_col], Number):
+ # check that colors has at least 2 colors
+ if len(colors) < 2:
+ raise exceptions.PlotlyError(
+ "You must use at least 2 colors in 'colors' if you "
+ "are using a colorscale. However only the first two "
+ "colors given will be used for the lower and upper "
+ "bounds on the colormap."
+ )
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted
+ # first are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ del tasks[index]['name']
+ tasks[index].update(shape_template)
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]['y0'] = groupID - bar_width
+ tasks[index]['y1'] = groupID + bar_width
+
+ # unlabel color
+ colors = utils.color_parser(colors, utils.unlabel_rgb)
+ lowcolor = colors[0]
+ highcolor = colors[1]
+
+ intermed = (chart[index][index_col]) / 100.0
+ intermed_color = utils.find_intermediate_color(
+ lowcolor, highcolor, intermed
+ )
+ intermed_color = utils.color_parser(
+ intermed_color, utils.label_rgb
+ )
+ tasks[index]['fillcolor'] = intermed_color
+ # relabel colors with 'rgb'
+ colors = utils.color_parser(colors, utils.label_rgb)
+
+ # add a line for hover text and autorange
+ entry = dict(
+ x=[tasks[index]['x0'], tasks[index]['x1']],
+ y=[groupID, groupID],
+ name='',
+ marker={'color': 'white'}
+ )
+ if "description" in tasks[index]:
+ entry['text'] = tasks[index]['description']
+ del tasks[index]['description']
+ data.append(entry)
+
+ if show_colorbar is True:
+ # generate dummy data for colorscale visibility
+ data.append(
+ dict(
+ x=[tasks[index]['x0'], tasks[index]['x0']],
+ y=[index, index],
+ name='',
+ marker={'color': 'white',
+ 'colorscale': [[0, colors[0]], [1, colors[1]]],
+ 'showscale': True,
+ 'cmax': 100,
+ 'cmin': 0}
+ )
+ )
+
+ if isinstance(chart[0][index_col], str):
+ index_vals = []
+ for row in range(len(tasks)):
+ if chart[row][index_col] not in index_vals:
+ index_vals.append(chart[row][index_col])
+
+ index_vals.sort()
+
+ if len(colors) < len(index_vals):
+ raise exceptions.PlotlyError(
+ "Error. The number of colors in 'colors' must be no less "
+ "than the number of unique index values in your group "
+ "column."
+ )
+
+ # make a dictionary assignment to each index value
+ index_vals_dict = {}
+ # define color index
+ c_index = 0
+ for key in index_vals:
+ if c_index > len(colors) - 1:
+ c_index = 0
+ index_vals_dict[key] = colors[c_index]
+ c_index += 1
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted
+ # first are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ del tasks[index]['name']
+ tasks[index].update(shape_template)
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]['y0'] = groupID - bar_width
+ tasks[index]['y1'] = groupID + bar_width
+
+ tasks[index]['fillcolor'] = index_vals_dict[
+ chart[index][index_col]
+ ]
+
+ # add a line for hover text and autorange
+ entry = dict(
+ x=[tasks[index]['x0'], tasks[index]['x1']],
+ y=[groupID, groupID],
+ name='',
+ marker={'color': 'white'}
+ )
+ if "description" in tasks[index]:
+ entry['text'] = tasks[index]['description']
+ del tasks[index]['description']
+ data.append(entry)
+
+ if show_colorbar is True:
+ # generate dummy data to generate legend
+ showlegend = True
+ for k, index_value in enumerate(index_vals):
+ data.append(
+ dict(
+ x=[tasks[index]['x0'], tasks[index]['x0']],
+ y=[k, k],
+ showlegend=True,
+ name=str(index_value),
+ hoverinfo='none',
+ marker=dict(
+ color=colors[k],
+ size=1
+ )
+ )
+ )
+
+ layout = dict(
+ title=title,
+ showlegend=showlegend,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode='closest',
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list([
+ dict(count=7,
+ label='1w',
+ step='day',
+ stepmode='backward'),
+ dict(count=1,
+ label='1m',
+ step='month',
+ stepmode='backward'),
+ dict(count=6,
+ label='6m',
+ step='month',
+ stepmode='backward'),
+ dict(count=1,
+ label='YTD',
+ step='year',
+ stepmode='todate'),
+ dict(count=1,
+ label='1y',
+ step='year',
+ stepmode='backward'),
+ dict(step='all')
+ ])
+ ),
+ type='date'
+ )
+ )
+ layout['shapes'] = tasks
+
+ fig = dict(data=data, layout=layout)
+ return fig
+
+
+def gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width,
+ showgrid_x, showgrid_y, height, width, tasks=None,
+ task_names=None, data=None, group_tasks=False):
+ """
+ Refer to FigureFactory.create_gantt() for docstring
+ """
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+ showlegend = False
+
+ for index in range(len(chart)):
+ task = dict(x0=chart[index]['Start'],
+ x1=chart[index]['Finish'],
+ name=chart[index]['Task'])
+ if 'Description' in chart[index]:
+ task['description'] = chart[index]['Description']
+ tasks.append(task)
+
+ shape_template = {
+ 'type': 'rect',
+ 'xref': 'x',
+ 'yref': 'y',
+ 'opacity': 1,
+ 'line': {
+ 'width': 0,
+ }
+ }
+
+ index_vals = []
+ for row in range(len(tasks)):
+ if chart[row][index_col] not in index_vals:
+ index_vals.append(chart[row][index_col])
+
+ index_vals.sort()
+
+ # verify each value in index column appears in colors dictionary
+ for key in index_vals:
+ if key not in colors:
+ raise exceptions.PlotlyError(
+ "If you are using colors as a dictionary, all of its "
+ "keys must be all the values in the index column."
+ )
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted first
+ # are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ del tasks[index]['name']
+ tasks[index].update(shape_template)
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]['y0'] = groupID - bar_width
+ tasks[index]['y1'] = groupID + bar_width
+
+ tasks[index]['fillcolor'] = colors[chart[index][index_col]]
+
+ # add a line for hover text and autorange
+ entry = dict(
+ x=[tasks[index]['x0'], tasks[index]['x1']],
+ y=[groupID, groupID],
+ name='',
+ marker={'color': 'white'}
+ )
+ if "description" in tasks[index]:
+ entry['text'] = tasks[index]['description']
+ del tasks[index]['description']
+ data.append(entry)
+
+ if show_colorbar is True:
+ # generate dummy data to generate legend
+ showlegend = True
+ for k, index_value in enumerate(index_vals):
+ data.append(
+ dict(
+ x=[tasks[index]['x0'], tasks[index]['x0']],
+ y=[k, k],
+ showlegend=True,
+ hoverinfo='none',
+ name=str(index_value),
+ marker=dict(
+ color=colors[index_value],
+ size=1
+ )
+ )
+ )
+
+ layout = dict(
+ title=title,
+ showlegend=showlegend,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode='closest',
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list([
+ dict(count=7,
+ label='1w',
+ step='day',
+ stepmode='backward'),
+ dict(count=1,
+ label='1m',
+ step='month',
+ stepmode='backward'),
+ dict(count=6,
+ label='6m',
+ step='month',
+ stepmode='backward'),
+ dict(count=1,
+ label='YTD',
+ step='year',
+ stepmode='todate'),
+ dict(count=1,
+ label='1y',
+ step='year',
+ stepmode='backward'),
+ dict(step='all')
+ ])
+ ),
+ type='date'
+ )
+ )
+ layout['shapes'] = tasks
+
+ fig = dict(data=data, layout=layout)
+ return fig
+
+
+def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
+ reverse_colors=False, title='Gantt Chart', bar_width=0.2,
+ showgrid_x=False, showgrid_y=False, height=600, width=900,
+ tasks=None, task_names=None, data=None, group_tasks=False):
+ """
+ Returns figure for a gantt chart
+
+ :param (array|list) df: input data for gantt chart. Must be either a
+ a dataframe or a list. If dataframe, the columns must include
+ 'Task', 'Start' and 'Finish'. Other columns can be included and
+ used for indexing. If a list, its elements must be dictionaries
+ with the same required column headers: 'Task', 'Start' and
+ 'Finish'.
+ :param (str|list|dict|tuple) colors: either a plotly scale name, an
+ rgb or hex color, a color tuple or a list of colors. An rgb color
+ is of the form 'rgb(x, y, z)' where x, y, z belong to the interval
+ [0, 255] and a color tuple is a tuple of the form (a, b, c) where
+ a, b and c belong to [0, 1]. If colors is a list, it must
+ contain the valid color types aforementioned as its members.
+ If a dictionary, all values of the indexing column must be keys in
+ colors.
+ :param (str|float) index_col: the column header (if df is a data
+ frame) that will function as the indexing column. If df is a list,
+ index_col must be one of the keys in all the items of df.
+ :param (bool) show_colorbar: determines if colorbar will be visible.
+ Only applies if values in the index column are numeric.
+ :param (bool) reverse_colors: reverses the order of selected colors
+ :param (str) title: the title of the chart
+ :param (float) bar_width: the width of the horizontal bars in the plot
+ :param (bool) showgrid_x: show/hide the x-axis grid
+ :param (bool) showgrid_y: show/hide the y-axis grid
+ :param (float) height: the height of the chart
+ :param (float) width: the width of the chart
+
+ Example 1: Simple Gantt Chart
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ # Make data for chart
+ df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'),
+ dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'),
+ dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')]
+
+ # Create a figure
+ fig = create_gantt(df)
+
+ # Plot the data
+ py.iplot(fig, filename='Simple Gantt Chart', world_readable=True)
+ ```
+
+ Example 2: Index by Column with Numerical Entries
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ # Make data for chart
+ df = [dict(Task="Job A", Start='2009-01-01',
+ Finish='2009-02-30', Complete=10),
+ dict(Task="Job B", Start='2009-03-05',
+ Finish='2009-04-15', Complete=60),
+ dict(Task="Job C", Start='2009-02-20',
+ Finish='2009-05-30', Complete=95)]
+
+ # Create a figure with Plotly colorscale
+ fig = create_gantt(df, colors='Blues', index_col='Complete',
+ show_colorbar=True, bar_width=0.5,
+ showgrid_x=True, showgrid_y=True)
+
+ # Plot the data
+ py.iplot(fig, filename='Numerical Entries', world_readable=True)
+ ```
+
+ Example 3: Index by Column with String Entries
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ # Make data for chart
+ df = [dict(Task="Job A", Start='2009-01-01',
+ Finish='2009-02-30', Resource='Apple'),
+ dict(Task="Job B", Start='2009-03-05',
+ Finish='2009-04-15', Resource='Grape'),
+ dict(Task="Job C", Start='2009-02-20',
+ Finish='2009-05-30', Resource='Banana')]
+
+ # Create a figure with Plotly colorscale
+ fig = create_gantt(df, colors=['rgb(200, 50, 25)', (1, 0, 1), '#6c4774'],
+ index_col='Resource', reverse_colors=True,
+ show_colorbar=True)
+
+ # Plot the data
+ py.iplot(fig, filename='String Entries', world_readable=True)
+ ```
+
+ Example 4: Use a dictionary for colors
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ # Make data for chart
+ df = [dict(Task="Job A", Start='2009-01-01',
+ Finish='2009-02-30', Resource='Apple'),
+ dict(Task="Job B", Start='2009-03-05',
+ Finish='2009-04-15', Resource='Grape'),
+ dict(Task="Job C", Start='2009-02-20',
+ Finish='2009-05-30', Resource='Banana')]
+
+ # Make a dictionary of colors
+ colors = {'Apple': 'rgb(255, 0, 0)',
+ 'Grape': 'rgb(170, 14, 200)',
+ 'Banana': (1, 1, 0.2)}
+
+ # Create a figure with Plotly colorscale
+ fig = create_gantt(df, colors=colors, index_col='Resource',
+ show_colorbar=True)
+
+ # Plot the data
+ py.iplot(fig, filename='dictioanry colors', world_readable=True)
+ ```
+
+ Example 5: Use a pandas dataframe
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ import pandas as pd
+
+ # Make data as a dataframe
+ df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10],
+ ['Fast', '2011-01-01', '2012-06-05', 55],
+ ['Eat', '2012-01-05', '2013-07-05', 94]],
+ columns=['Task', 'Start', 'Finish', 'Complete'])
+
+ # Create a figure with Plotly colorscale
+ fig = create_gantt(df, colors='Blues', index_col='Complete',
+ show_colorbar=True, bar_width=0.5,
+ showgrid_x=True, showgrid_y=True)
+
+ # Plot the data
+ py.iplot(fig, filename='data with dataframe', world_readable=True)
+ ```
+ """
+ # validate gantt input data
+ chart = validate_gantt(df)
+
+ if index_col:
+ if index_col not in chart[0]:
+ raise exceptions.PlotlyError(
+ "In order to use an indexing column and assign colors to "
+ "the values of the index, you must choose an actual "
+ "column name in the dataframe or key if a list of "
+ "dictionaries is being used.")
+
+ # validate gantt index column
+ index_list = []
+ for dictionary in chart:
+ index_list.append(dictionary[index_col])
+ utils.validate_index(index_list)
+
+ # Validate colors
+ if isinstance(colors, dict):
+ colors = utils.validate_colors_dict(colors, 'rgb')
+ else:
+ colors = utils.validate_colors(colors, 'rgb')
+
+ if reverse_colors is True:
+ colors.reverse()
+
+ if not index_col:
+ if isinstance(colors, dict):
+ raise exceptions.PlotlyError(
+ "Error. You have set colors to a dictionary but have not "
+ "picked an index. An index is required if you are "
+ "assigning colors to particular values in a dictioanry."
+ )
+ fig = gantt(
+ chart, colors, title, bar_width, showgrid_x, showgrid_y,
+ height, width, tasks=None, task_names=None, data=None,
+ group_tasks=group_tasks
+ )
+ return fig
+ else:
+ if not isinstance(colors, dict):
+ fig = gantt_colorscale(
+ chart, colors, title, index_col, show_colorbar, bar_width,
+ showgrid_x, showgrid_y, height, width,
+ tasks=None, task_names=None, data=None, group_tasks=group_tasks
+ )
+ return fig
+ else:
+ fig = gantt_dict(
+ chart, colors, title, index_col, show_colorbar, bar_width,
+ showgrid_x, showgrid_y, height, width,
+ tasks=None, task_names=None, data=None, group_tasks=group_tasks
+ )
+ return fig
diff --git a/plotly/figure_factory/_ohlc.py b/plotly/figure_factory/_ohlc.py
new file mode 100644
index 00000000000..b5e84cd6d93
--- /dev/null
+++ b/plotly/figure_factory/_ohlc.py
@@ -0,0 +1,380 @@
+from __future__ import absolute_import
+
+from plotly import exceptions
+from plotly.graph_objs import graph_objs
+from plotly.figure_factory import utils
+
+
+# Default colours for finance charts
+_DEFAULT_INCREASING_COLOR = '#3D9970' # http://clrs.cc
+_DEFAULT_DECREASING_COLOR = '#FF4136'
+
+
+def validate_ohlc(open, high, low, close, direction, **kwargs):
+ """
+ ohlc and candlestick specific validations
+
+ Specifically, this checks that the high value is the greatest value and
+ the low value is the lowest value in each unit.
+
+ See FigureFactory.create_ohlc() or FigureFactory.create_candlestick()
+ for params
+
+ :raises: (PlotlyError) If the high value is not the greatest value in
+ each unit.
+ :raises: (PlotlyError) If the low value is not the lowest value in each
+ unit.
+ :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing'
+ """
+ for lst in [open, low, close]:
+ for index in range(len(high)):
+ if high[index] < lst[index]:
+ raise exceptions.PlotlyError("Oops! Looks like some of "
+ "your high values are less "
+ "the corresponding open, "
+ "low, or close values. "
+ "Double check that your data "
+ "is entered in O-H-L-C order")
+
+ for lst in [open, high, close]:
+ for index in range(len(low)):
+ if low[index] > lst[index]:
+ raise exceptions.PlotlyError("Oops! Looks like some of "
+ "your low values are greater "
+ "than the corresponding high"
+ ", open, or close values. "
+ "Double check that your data "
+ "is entered in O-H-L-C order")
+
+ direction_opts = ('increasing', 'decreasing', 'both')
+ if direction not in direction_opts:
+ raise exceptions.PlotlyError("direction must be defined as "
+ "'increasing', 'decreasing', or "
+ "'both'")
+
+
+def make_increasing_ohlc(open, high, low, close, dates, **kwargs):
+ """
+ Makes increasing ohlc sticks
+
+ _make_increasing_ohlc() and _make_decreasing_ohlc separate the
+ increasing trace from the decreasing trace so kwargs (such as
+ color) can be passed separately to increasing or decreasing traces
+ when direction is set to 'increasing' or 'decreasing' in
+ FigureFactory.create_candlestick()
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc
+ sticks.
+ """
+ (flat_increase_x,
+ flat_increase_y,
+ text_increase) = _OHLC(open, high, low, close, dates).get_increase()
+
+ if 'name' in kwargs:
+ showlegend = True
+ else:
+ kwargs.setdefault('name', 'Increasing')
+ showlegend = False
+
+ kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR,
+ width=1))
+ kwargs.setdefault('text', text_increase)
+
+ ohlc_incr = dict(type='scatter',
+ x=flat_increase_x,
+ y=flat_increase_y,
+ mode='lines',
+ showlegend=showlegend,
+ **kwargs)
+ return ohlc_incr
+
+
+def make_decreasing_ohlc(open, high, low, close, dates, **kwargs):
+ """
+ Makes decreasing ohlc sticks
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc
+ sticks.
+ """
+ (flat_decrease_x,
+ flat_decrease_y,
+ text_decrease) = _OHLC(open, high, low, close, dates).get_decrease()
+
+ kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR,
+ width=1))
+ kwargs.setdefault('text', text_decrease)
+ kwargs.setdefault('showlegend', False)
+ kwargs.setdefault('name', 'Decreasing')
+
+ ohlc_decr = dict(type='scatter',
+ x=flat_decrease_x,
+ y=flat_decrease_y,
+ mode='lines',
+ **kwargs)
+ return ohlc_decr
+
+
+def create_ohlc(open, high, low, close, dates=None, direction='both',
+ **kwargs):
+ """
+ BETA function that creates an ohlc chart
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing
+ :param (list) dates: list of datetime objects. Default: None
+ :param (string) direction: direction can be 'increasing', 'decreasing',
+ or 'both'. When the direction is 'increasing', the returned figure
+ consists of all units where the close value is greater than the
+ corresponding open value, and when the direction is 'decreasing',
+ the returned figure consists of all units where the close value is
+ less than or equal to the corresponding open value. When the
+ direction is 'both', both increasing and decreasing units are
+ returned. Default: 'both'
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
+ These kwargs describe other attributes about the ohlc Scatter trace
+ such as the color or the legend name. For more information on valid
+ kwargs call help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of an ohlc chart figure.
+
+ Example 1: Simple OHLC chart from a Pandas DataFrame
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_ohlc
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15),
+ datetime(2008, 10, 15))
+ fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index)
+
+ py.plot(fig, filename='finance/aapl-ohlc')
+ ```
+
+ Example 2: Add text and annotations to the OHLC chart
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_ohlc
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15),
+ datetime(2008, 10, 15))
+ fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index)
+
+ # Update the fig - options here: https://plot.ly/python/reference/#Layout
+ fig['layout'].update({
+ 'title': 'The Great Recession',
+ 'yaxis': {'title': 'AAPL Stock'},
+ 'shapes': [{
+ 'x0': '2008-09-15', 'x1': '2008-09-15', 'type': 'line',
+ 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
+ 'line': {'color': 'rgb(40,40,40)', 'width': 0.5}
+ }],
+ 'annotations': [{
+ 'text': "the fall of Lehman Brothers",
+ 'x': '2008-09-15', 'y': 1.02,
+ 'xref': 'x', 'yref': 'paper',
+ 'showarrow': False, 'xanchor': 'left'
+ }]
+ })
+
+ py.plot(fig, filename='finance/aapl-recession-ohlc', validate=False)
+ ```
+
+ Example 3: Customize the OHLC colors
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_ohlc
+ from plotly.graph_objs import Line, Marker
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1),
+ datetime(2009, 4, 1))
+
+ # Make increasing ohlc sticks and customize their color and name
+ fig_increasing = create_ohlc(df.Open, df.High, df.Low, df.Close,
+ dates=df.index, direction='increasing',
+ name='AAPL',
+ line=Line(color='rgb(150, 200, 250)'))
+
+ # Make decreasing ohlc sticks and customize their color and name
+ fig_decreasing = create_ohlc(df.Open, df.High, df.Low, df.Close,
+ dates=df.index, direction='decreasing',
+ line=Line(color='rgb(128, 128, 128)'))
+
+ # Initialize the figure
+ fig = fig_increasing
+
+ # Add decreasing data with .extend()
+ fig['data'].extend(fig_decreasing['data'])
+
+ py.iplot(fig, filename='finance/aapl-ohlc-colors', validate=False)
+ ```
+
+ Example 4: OHLC chart with datetime objects
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_ohlc
+
+ from datetime import datetime
+
+ # Add data
+ open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
+ high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
+ low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
+ close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
+ dates = [datetime(year=2013, month=10, day=10),
+ datetime(year=2013, month=11, day=10),
+ datetime(year=2013, month=12, day=10),
+ datetime(year=2014, month=1, day=10),
+ datetime(year=2014, month=2, day=10)]
+
+ # Create ohlc
+ fig = create_ohlc(open_data, high_data, low_data, close_data, dates=dates)
+
+ py.iplot(fig, filename='finance/simple-ohlc', validate=False)
+ ```
+ """
+ if dates is not None:
+ utils.validate_equal_length(open, high, low, close, dates)
+ else:
+ utils.validate_equal_length(open, high, low, close)
+ validate_ohlc(open, high, low, close, direction, **kwargs)
+
+ if direction is 'increasing':
+ ohlc_incr = make_increasing_ohlc(open, high, low, close, dates,
+ **kwargs)
+ data = [ohlc_incr]
+ elif direction is 'decreasing':
+ ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates,
+ **kwargs)
+ data = [ohlc_decr]
+ else:
+ ohlc_incr = make_increasing_ohlc(open, high, low, close, dates,
+ **kwargs)
+ ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates,
+ **kwargs)
+ data = [ohlc_incr, ohlc_decr]
+
+ layout = graph_objs.Layout(xaxis=dict(zeroline=False),
+ hovermode='closest')
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _OHLC(object):
+ """
+ Refer to FigureFactory.create_ohlc_increase() for docstring.
+ """
+ def __init__(self, open, high, low, close, dates, **kwargs):
+ self.open = open
+ self.high = high
+ self.low = low
+ self.close = close
+ self.empty = [None] * len(open)
+ self.dates = dates
+
+ self.all_x = []
+ self.all_y = []
+ self.increase_x = []
+ self.increase_y = []
+ self.decrease_x = []
+ self.decrease_y = []
+ self.get_all_xy()
+ self.separate_increase_decrease()
+
+ def get_all_xy(self):
+ """
+ Zip data to create OHLC shape
+
+ OHLC shape: low to high vertical bar with
+ horizontal branches for open and close values.
+ If dates were added, the smallest date difference is calculated and
+ multiplied by .2 to get the length of the open and close branches.
+ If no date data was provided, the x-axis is a list of integers and the
+ length of the open and close branches is .2.
+ """
+ self.all_y = list(zip(self.open, self.open, self.high,
+ self.low, self.close, self.close, self.empty))
+ if self.dates is not None:
+ date_dif = []
+ for i in range(len(self.dates) - 1):
+ date_dif.append(self.dates[i + 1] - self.dates[i])
+ date_dif_min = (min(date_dif)) / 5
+ self.all_x = [[x - date_dif_min, x, x, x, x, x +
+ date_dif_min, None] for x in self.dates]
+ else:
+ self.all_x = [[x - .2, x, x, x, x, x + .2, None]
+ for x in range(len(self.open))]
+
+ def separate_increase_decrease(self):
+ """
+ Separate data into two groups: increase and decrease
+
+ (1) Increase, where close > open and
+ (2) Decrease, where close <= open
+ """
+ for index in range(len(self.open)):
+ if self.close[index] is None:
+ pass
+ elif self.close[index] > self.open[index]:
+ self.increase_x.append(self.all_x[index])
+ self.increase_y.append(self.all_y[index])
+ else:
+ self.decrease_x.append(self.all_x[index])
+ self.decrease_y.append(self.all_y[index])
+
+ def get_increase(self):
+ """
+ Flatten increase data and get increase text
+
+ :rtype (list, list, list): flat_increase_x: x-values for the increasing
+ trace, flat_increase_y: y=values for the increasing trace and
+ text_increase: hovertext for the increasing trace
+ """
+ flat_increase_x = utils.flatten(self.increase_x)
+ flat_increase_y = utils.flatten(self.increase_y)
+ text_increase = (("Open", "Open", "High",
+ "Low", "Close", "Close", '')
+ * (len(self.increase_x)))
+
+ return flat_increase_x, flat_increase_y, text_increase
+
+ def get_decrease(self):
+ """
+ Flatten decrease data and get decrease text
+
+ :rtype (list, list, list): flat_decrease_x: x-values for the decreasing
+ trace, flat_decrease_y: y=values for the decreasing trace and
+ text_decrease: hovertext for the decreasing trace
+ """
+ flat_decrease_x = utils.flatten(self.decrease_x)
+ flat_decrease_y = utils.flatten(self.decrease_y)
+ text_decrease = (("Open", "Open", "High",
+ "Low", "Close", "Close", '')
+ * (len(self.decrease_x)))
+
+ return flat_decrease_x, flat_decrease_y, text_decrease
diff --git a/plotly/figure_factory/_quiver.py b/plotly/figure_factory/_quiver.py
new file mode 100644
index 00000000000..8d0de352baf
--- /dev/null
+++ b/plotly/figure_factory/_quiver.py
@@ -0,0 +1,243 @@
+from __future__ import absolute_import
+
+import math
+
+from plotly import exceptions
+from plotly.graph_objs import graph_objs
+from plotly.figure_factory import utils
+
+
+def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3,
+ angle=math.pi / 9, **kwargs):
+ """
+ Returns data for a quiver plot.
+
+ :param (list|ndarray) x: x coordinates of the arrow locations
+ :param (list|ndarray) y: y coordinates of the arrow locations
+ :param (list|ndarray) u: x components of the arrow vectors
+ :param (list|ndarray) v: y components of the arrow vectors
+ :param (float in [0,1]) scale: scales size of the arrows(ideally to
+ avoid overlap). Default = .1
+ :param (float in [0,1]) arrow_scale: value multiplied to length of barb
+ to get length of arrowhead. Default = .3
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter
+ for more information on valid kwargs call
+ help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of quiver figure.
+
+ Example 1: Trivial Quiver
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_quiver
+
+ import math
+
+ # 1 Arrow from (0,0) to (1,1)
+ fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1)
+
+ py.plot(fig, filename='quiver')
+ ```
+
+ Example 2: Quiver plot using meshgrid
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_quiver
+
+ import numpy as np
+ import math
+
+ # Add data
+ x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
+ u = np.cos(x)*y
+ v = np.sin(x)*y
+
+ #Create quiver
+ fig = create_quiver(x, y, u, v)
+
+ # Plot
+ py.plot(fig, filename='quiver')
+ ```
+
+ Example 3: Styling the quiver plot
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_quiver
+ import numpy as np
+ import math
+
+ # Add data
+ x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5),
+ np.arange(-math.pi, math.pi, .5))
+ u = np.cos(x)*y
+ v = np.sin(x)*y
+
+ # Create quiver
+ fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6,
+ name='Wind Velocity', line=Line(width=1))
+
+ # Add title to layout
+ fig['layout'].update(title='Quiver Plot')
+
+ # Plot
+ py.plot(fig, filename='quiver')
+ ```
+ """
+ utils.validate_equal_length(x, y, u, v)
+ utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale)
+
+ barb_x, barb_y = _Quiver(x, y, u, v, scale,
+ arrow_scale, angle).get_barbs()
+ arrow_x, arrow_y = _Quiver(x, y, u, v, scale,
+ arrow_scale, angle).get_quiver_arrows()
+ quiver = graph_objs.Scatter(x=barb_x + arrow_x,
+ y=barb_y + arrow_y,
+ mode='lines', **kwargs)
+
+ data = [quiver]
+ layout = graph_objs.Layout(hovermode='closest')
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Quiver(object):
+ """
+ Refer to FigureFactory.create_quiver() for docstring
+ """
+ def __init__(self, x, y, u, v,
+ scale, arrow_scale, angle, **kwargs):
+ try:
+ x = utils.flatten(x)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ y = utils.flatten(y)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ u = utils.flatten(u)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ v = utils.flatten(v)
+ except exceptions.PlotlyError:
+ pass
+
+ self.x = x
+ self.y = y
+ self.u = u
+ self.v = v
+ self.scale = scale
+ self.arrow_scale = arrow_scale
+ self.angle = angle
+ self.end_x = []
+ self.end_y = []
+ self.scale_uv()
+ barb_x, barb_y = self.get_barbs()
+ arrow_x, arrow_y = self.get_quiver_arrows()
+
+ def scale_uv(self):
+ """
+ Scales u and v to avoid overlap of the arrows.
+
+ u and v are added to x and y to get the
+ endpoints of the arrows so a smaller scale value will
+ result in less overlap of arrows.
+ """
+ self.u = [i * self.scale for i in self.u]
+ self.v = [i * self.scale for i in self.v]
+
+ def get_barbs(self):
+ """
+ Creates x and y startpoint and endpoint pairs
+
+ After finding the endpoint of each barb this zips startpoint and
+ endpoint pairs to create 2 lists: x_values for barbs and y values
+ for barbs
+
+ :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint
+ x_value pairs separated by a None to create the barb of the arrow,
+ and list of startpoint and endpoint y_value pairs separated by a
+ None to create the barb of the arrow.
+ """
+ self.end_x = [i + j for i, j in zip(self.x, self.u)]
+ self.end_y = [i + j for i, j in zip(self.y, self.v)]
+ empty = [None] * len(self.x)
+ barb_x = utils.flatten(zip(self.x, self.end_x, empty))
+ barb_y = utils.flatten(zip(self.y, self.end_y, empty))
+ return barb_x, barb_y
+
+ def get_quiver_arrows(self):
+ """
+ Creates lists of x and y values to plot the arrows
+
+ Gets length of each barb then calculates the length of each side of
+ the arrow. Gets angle of barb and applies angle to each side of the
+ arrowhead. Next uses arrow_scale to scale the length of arrowhead and
+ creates x and y values for arrowhead point1 and point2. Finally x and y
+ values for point1, endpoint and point2s for each arrowhead are
+ separated by a None and zipped to create lists of x and y values for
+ the arrows.
+
+ :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2
+ x_values separated by a None to create the arrowhead and list of
+ point1, endpoint, point2 y_values separated by a None to create
+ the barb of the arrow.
+ """
+ dif_x = [i - j for i, j in zip(self.end_x, self.x)]
+ dif_y = [i - j for i, j in zip(self.end_y, self.y)]
+
+ # Get barb lengths(default arrow length = 30% barb length)
+ barb_len = [None] * len(self.x)
+ for index in range(len(barb_len)):
+ barb_len[index] = math.hypot(dif_x[index], dif_y[index])
+
+ # Make arrow lengths
+ arrow_len = [None] * len(self.x)
+ arrow_len = [i * self.arrow_scale for i in barb_len]
+
+ # Get barb angles
+ barb_ang = [None] * len(self.x)
+ for index in range(len(barb_ang)):
+ barb_ang[index] = math.atan2(dif_y[index], dif_x[index])
+
+ # Set angles to create arrow
+ ang1 = [i + self.angle for i in barb_ang]
+ ang2 = [i - self.angle for i in barb_ang]
+
+ cos_ang1 = [None] * len(ang1)
+ for index in range(len(ang1)):
+ cos_ang1[index] = math.cos(ang1[index])
+ seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)]
+
+ sin_ang1 = [None] * len(ang1)
+ for index in range(len(ang1)):
+ sin_ang1[index] = math.sin(ang1[index])
+ seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)]
+
+ cos_ang2 = [None] * len(ang2)
+ for index in range(len(ang2)):
+ cos_ang2[index] = math.cos(ang2[index])
+ seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)]
+
+ sin_ang2 = [None] * len(ang2)
+ for index in range(len(ang2)):
+ sin_ang2[index] = math.sin(ang2[index])
+ seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)]
+
+ # Set coordinates to create arrow
+ for index in range(len(self.end_x)):
+ point1_x = [i - j for i, j in zip(self.end_x, seg1_x)]
+ point1_y = [i - j for i, j in zip(self.end_y, seg1_y)]
+ point2_x = [i - j for i, j in zip(self.end_x, seg2_x)]
+ point2_y = [i - j for i, j in zip(self.end_y, seg2_y)]
+
+ # Combine lists to create arrow
+ empty = [None] * len(self.end_x)
+ arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty))
+ arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty))
+ return arrow_x, arrow_y
diff --git a/plotly/figure_factory/_scatterplot.py b/plotly/figure_factory/_scatterplot.py
new file mode 100644
index 00000000000..e9926b26f86
--- /dev/null
+++ b/plotly/figure_factory/_scatterplot.py
@@ -0,0 +1,1136 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+from plotly.tools import make_subplots
+
+pd = optional_imports.get_module('pandas')
+
+DIAG_CHOICES = ['scatter', 'histogram', 'box']
+VALID_COLORMAP_TYPES = ['cat', 'seq']
+
+
+def endpts_to_intervals(endpts):
+ """
+ Returns a list of intervals for categorical colormaps
+
+ Accepts a list or tuple of sequentially increasing numbers and returns
+ a list representation of the mathematical intervals with these numbers
+ as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]
+
+ :raises: (PlotlyError) If input is not a list or tuple
+ :raises: (PlotlyError) If the input contains a string
+ :raises: (PlotlyError) If any number does not increase after the
+ previous one in the sequence
+ """
+ length = len(endpts)
+ # Check if endpts is a list or tuple
+ if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))):
+ raise exceptions.PlotlyError("The intervals_endpts argument must "
+ "be a list or tuple of a sequence "
+ "of increasing numbers.")
+ # Check if endpts contains only numbers
+ for item in endpts:
+ if isinstance(item, str):
+ raise exceptions.PlotlyError("The intervals_endpts argument "
+ "must be a list or tuple of a "
+ "sequence of increasing "
+ "numbers.")
+ # Check if numbers in endpts are increasing
+ for k in range(length - 1):
+ if endpts[k] >= endpts[k + 1]:
+ raise exceptions.PlotlyError("The intervals_endpts argument "
+ "must be a list or tuple of a "
+ "sequence of increasing "
+ "numbers.")
+ else:
+ intervals = []
+ # add -inf to intervals
+ intervals.append([float('-inf'), endpts[0]])
+ for k in range(length - 1):
+ interval = []
+ interval.append(endpts[k])
+ interval.append(endpts[k + 1])
+ intervals.append(interval)
+ # add +inf to intervals
+ intervals.append([endpts[length - 1], float('inf')])
+ return intervals
+
+
+def hide_tick_labels_from_box_subplots(fig):
+ """
+ Hides tick labels for box plots in scatterplotmatrix subplots.
+ """
+ boxplot_xaxes = []
+ for trace in fig['data']:
+ if trace['type'] == 'box':
+ # stores the xaxes which correspond to boxplot subplots
+ # since we use xaxis1, xaxis2, etc, in plotly.py
+ boxplot_xaxes.append(
+ 'xaxis{}'.format(trace['xaxis'][1:])
+ )
+ for xaxis in boxplot_xaxes:
+ fig['layout'][xaxis]['showticklabels'] = False
+
+
+def validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs):
+ """
+ Validates basic inputs for FigureFactory.create_scatterplotmatrix()
+
+ :raises: (PlotlyError) If pandas is not imported
+ :raises: (PlotlyError) If pandas dataframe is not inputted
+ :raises: (PlotlyError) If pandas dataframe has <= 1 columns
+ :raises: (PlotlyError) If diagonal plot choice (diag) is not one of
+ the viable options
+ :raises: (PlotlyError) If colormap_type is not a valid choice
+ :raises: (PlotlyError) If kwargs contains 'size', 'color' or
+ 'colorscale'
+ """
+ if not pd:
+ raise ImportError("FigureFactory.scatterplotmatrix requires "
+ "a pandas DataFrame.")
+
+ # Check if pandas dataframe
+ if not isinstance(df, pd.core.frame.DataFrame):
+ raise exceptions.PlotlyError("Dataframe not inputed. Please "
+ "use a pandas dataframe to pro"
+ "duce a scatterplot matrix.")
+
+ # Check if dataframe is 1 column or less
+ if len(df.columns) <= 1:
+ raise exceptions.PlotlyError("Dataframe has only one column. To "
+ "use the scatterplot matrix, use at "
+ "least 2 columns.")
+
+ # Check that diag parameter is a valid selection
+ if diag not in DIAG_CHOICES:
+ raise exceptions.PlotlyError("Make sure diag is set to "
+ "one of {}".format(DIAG_CHOICES))
+
+ # Check that colormap_types is a valid selection
+ if colormap_type not in VALID_COLORMAP_TYPES:
+ raise exceptions.PlotlyError("Must choose a valid colormap type. "
+ "Either 'cat' or 'seq' for a cate"
+ "gorical and sequential colormap "
+ "respectively.")
+
+ # Check for not 'size' or 'color' in 'marker' of **kwargs
+ if 'marker' in kwargs:
+ FORBIDDEN_PARAMS = ['size', 'color', 'colorscale']
+ if any(param in kwargs['marker'] for param in FORBIDDEN_PARAMS):
+ raise exceptions.PlotlyError("Your kwargs dictionary cannot "
+ "include the 'size', 'color' or "
+ "'colorscale' key words inside "
+ "the marker dict since 'size' is "
+ "already an argument of the "
+ "scatterplot matrix function and "
+ "both 'color' and 'colorscale "
+ "are set internally.")
+
+
+def scatterplot(dataframe, headers, diag, size, height, width, title,
+ **kwargs):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix without index
+
+ """
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ # Insert traces into trace_list
+ for listy in dataframe:
+ for listx in dataframe:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=listx,
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=listx,
+ name=None,
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ showlegend=False,
+ **kwargs
+ )
+ trace_list.append(trace)
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ marker=dict(
+ size=size),
+ showlegend=False,
+ **kwargs
+ )
+ trace_list.append(trace)
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ fig.append_trace(trace_list[trace_index],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True
+ )
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ return fig
+
+
+def scatterplot_dict(dataframe, headers, diag, size,
+ height, width, title, index, index_vals,
+ endpts, colormap, colormap_type, **kwargs):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix with both index and colormap picked.
+ Used if colormap is a dictionary with index values as keys pointing to
+ colors. Forces colormap_type to behave categorically because it would
+ not make sense colors are assigned to each index value and thus
+ implies that a categorical approach should be taken
+
+ """
+
+ theme = colormap
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # create a dictionary for index_vals
+ unique_index_vals = {}
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals[name] = []
+
+ # Fill all the rest of the names into the dictionary
+ for name in sorted(unique_index_vals.keys()):
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if index_vals[j] == name:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[name]),
+ showlegend=True
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[name]),
+ showlegend=True
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = theme[name]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ showlegend=True,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ marker=dict(
+ size=size,
+ color=theme[name]),
+ showlegend=True,
+ **kwargs
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[name]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[name]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = theme[name]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ marker=dict(
+ size=size,
+ color=theme[name]),
+ showlegend=False,
+ **kwargs
+ )
+ # Push the trace into dictionary
+ unique_index_vals[name] = trace
+ trace_list.append(unique_index_vals)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for name in sorted(trace_list[trace_index].keys()):
+ fig.append_trace(
+ trace_list[trace_index][name],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == 'histogram':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True,
+ barmode='stack')
+ return fig
+
+ else:
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+
+def scatterplot_theme(dataframe, headers, diag, size, height, width, title,
+ index, index_vals, endpts, colormap, colormap_type,
+ **kwargs):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix with both index and colormap picked
+
+ """
+
+ # Check if index is made of string values
+ if isinstance(index_vals[0], str):
+ unique_index_vals = []
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals.append(name)
+ n_colors_len = len(unique_index_vals)
+
+ # Convert colormap to list of n RGB tuples
+ if colormap_type == 'seq':
+ foo = utils.color_parser(colormap, utils.unlabel_rgb)
+ foo = utils.n_colors(foo[0], foo[1], n_colors_len)
+ theme = utils.color_parser(foo, utils.label_rgb)
+
+ if colormap_type == 'cat':
+ # leave list of colors the same way
+ theme = colormap
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # create a dictionary for index_vals
+ unique_index_vals = {}
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals[name] = []
+
+ c_indx = 0 # color index
+ # Fill all the rest of the names into the dictionary
+ for name in sorted(unique_index_vals.keys()):
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if index_vals[j] == name:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=True
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=True
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ showlegend=True,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ marker=dict(
+ size=size,
+ color=theme[c_indx]),
+ showlegend=True,
+ **kwargs
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ marker=dict(
+ size=size,
+ color=theme[c_indx]),
+ showlegend=False,
+ **kwargs
+ )
+ # Push the trace into dictionary
+ unique_index_vals[name] = trace
+ if c_indx >= (len(theme) - 1):
+ c_indx = -1
+ c_indx += 1
+ trace_list.append(unique_index_vals)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for name in sorted(trace_list[trace_index].keys()):
+ fig.append_trace(
+ trace_list[trace_index][name],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == 'histogram':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True,
+ barmode='stack')
+ return fig
+
+ elif diag == 'box':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ if endpts:
+ intervals = endpts_to_intervals(endpts)
+
+ # Convert colormap to list of n RGB tuples
+ if colormap_type == 'seq':
+ foo = utils.color_parser(colormap, utils.unlabel_rgb)
+ foo = utils.n_colors(foo[0], foo[1], len(intervals))
+ theme = utils.color_parser(foo, utils.label_rgb)
+
+ if colormap_type == 'cat':
+ # leave list of colors the same way
+ theme = colormap
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ interval_labels = {}
+ for interval in intervals:
+ interval_labels[str(interval)] = []
+
+ c_indx = 0 # color index
+ # Fill all the rest of the names into the dictionary
+ for interval in intervals:
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if interval[0] < index_vals[j] <= interval[1]:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=True
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=True
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ (kwargs['marker']
+ ['color']) = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=str(interval),
+ showlegend=True,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=str(interval),
+ marker=dict(
+ size=size,
+ color=theme[c_indx]),
+ showlegend=True,
+ **kwargs
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ (kwargs['marker']
+ ['color']) = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=str(interval),
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=str(interval),
+ marker=dict(
+ size=size,
+ color=theme[c_indx]),
+ showlegend=False,
+ **kwargs
+ )
+ # Push the trace into dictionary
+ interval_labels[str(interval)] = trace
+ if c_indx >= (len(theme) - 1):
+ c_indx = -1
+ c_indx += 1
+ trace_list.append(interval_labels)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for interval in intervals:
+ fig.append_trace(
+ trace_list[trace_index][str(interval)],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == 'histogram':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True,
+ barmode='stack')
+ return fig
+
+ elif diag == 'box':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ theme = colormap
+
+ # add a copy of rgb color to theme if it contains one color
+ if len(theme) <= 1:
+ theme.append(theme[0])
+
+ color = []
+ for incr in range(len(theme)):
+ color.append([1. / (len(theme) - 1) * incr, theme[incr]])
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Run through all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=listx,
+ marker=dict(
+ color=theme[0]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=listx,
+ marker=dict(
+ color=theme[0]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = index_vals
+ kwargs['marker']['colorscale'] = color
+ kwargs['marker']['showscale'] = True
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ marker=dict(
+ size=size,
+ color=index_vals,
+ colorscale=color,
+ showscale=True),
+ showlegend=False,
+ **kwargs
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=listx,
+ marker=dict(
+ color=theme[0]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=listx,
+ marker=dict(
+ color=theme[0]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = index_vals
+ kwargs['marker']['colorscale'] = color
+ kwargs['marker']['showscale'] = False
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ marker=dict(
+ size=size,
+ color=index_vals,
+ colorscale=color,
+ showscale=False),
+ showlegend=False,
+ **kwargs
+ )
+ # Push the trace into list
+ trace_list.append(trace)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ fig.append_trace(trace_list[trace_index],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == 'histogram':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True,
+ barmode='stack')
+ return fig
+
+ elif diag == 'box':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+
+def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter',
+ height=500, width=500, size=6,
+ title='Scatterplot Matrix', colormap=None,
+ colormap_type='cat', dataframe=None,
+ headers=None, index_vals=None, **kwargs):
+ """
+ Returns data for a scatterplot matrix.
+
+ :param (array) df: array of the data with column headers
+ :param (str) index: name of the index column in data array
+ :param (list|tuple) endpts: takes an increasing sequece of numbers
+ that defines intervals on the real line. They are used to group
+ the entries in an index of numbers into their corresponding
+ interval and therefore can be treated as categorical data
+ :param (str) diag: sets the chart type for the main diagonal plots.
+ The options are 'scatter', 'histogram' and 'box'.
+ :param (int|float) height: sets the height of the chart
+ :param (int|float) width: sets the width of the chart
+ :param (float) size: sets the marker size (in px)
+ :param (str) title: the title label of the scatterplot matrix
+ :param (str|tuple|list|dict) colormap: either a plotly scale name,
+ an rgb or hex color, a color tuple, a list of colors or a
+ dictionary. An rgb color is of the form 'rgb(x, y, z)' where
+ x, y and z belong to the interval [0, 255] and a color tuple is a
+ tuple of the form (a, b, c) where a, b and c belong to [0, 1].
+ If colormap is a list, it must contain valid color types as its
+ members.
+ If colormap is a dictionary, all the string entries in
+ the index column must be a key in colormap. In this case, the
+ colormap_type is forced to 'cat' or categorical
+ :param (str) colormap_type: determines how colormap is interpreted.
+ Valid choices are 'seq' (sequential) and 'cat' (categorical). If
+ 'seq' is selected, only the first two colors in colormap will be
+ considered (when colormap is a list) and the index values will be
+ linearly interpolated between those two colors. This option is
+ forced if all index values are numeric.
+ If 'cat' is selected, a color from colormap will be assigned to
+ each category from index, including the intervals if endpts is
+ being used
+ :param (dict) **kwargs: a dictionary of scatterplot arguments
+ The only forbidden parameters are 'size', 'color' and
+ 'colorscale' in 'marker'
+
+ Example 1: Vanilla Scatterplot Matrix
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe
+ df = pd.DataFrame(np.random.randn(10, 2),
+ columns=['Column 1', 'Column 2'])
+
+ # Create scatterplot matrix
+ fig = create_scatterplotmatrix(df)
+
+ # Plot
+ py.iplot(fig, filename='Vanilla Scatterplot Matrix')
+ ```
+
+ Example 2: Indexing a Column
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe with index
+ df = pd.DataFrame(np.random.randn(10, 2),
+ columns=['A', 'B'])
+
+ # Add another column of strings to the dataframe
+ df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
+ 'grape', 'pear', 'pear', 'apple', 'pear'])
+
+ # Create scatterplot matrix
+ fig = create_scatterplotmatrix(df, index='Fruit', size=10)
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix with Index')
+ ```
+
+ Example 3: Styling the Diagonal Subplots
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe with index
+ df = pd.DataFrame(np.random.randn(10, 4),
+ columns=['A', 'B', 'C', 'D'])
+
+ # Add another column of strings to the dataframe
+ df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
+ 'grape', 'pear', 'pear', 'apple', 'pear'])
+
+ # Create scatterplot matrix
+ fig = create_scatterplotmatrix(df, diag='box', index='Fruit', height=1000,
+ width=1000)
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix - Diagonal Styling')
+ ```
+
+ Example 4: Use a Theme to Style the Subplots
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe with random data
+ df = pd.DataFrame(np.random.randn(100, 3),
+ columns=['A', 'B', 'C'])
+
+ # Create scatterplot matrix using a built-in
+ # Plotly palette scale and indexing column 'A'
+ fig = create_scatterplotmatrix(df, diag='histogram', index='A',
+ colormap='Blues', height=800, width=800)
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix - Colormap Theme')
+ ```
+
+ Example 5: Example 4 with Interval Factoring
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe with random data
+ df = pd.DataFrame(np.random.randn(100, 3),
+ columns=['A', 'B', 'C'])
+
+ # Create scatterplot matrix using a list of 2 rgb tuples
+ # and endpoints at -1, 0 and 1
+ fig = create_scatterplotmatrix(df, diag='histogram', index='A',
+ colormap=['rgb(140, 255, 50)',
+ 'rgb(170, 60, 115)', '#6c4774',
+ (0.5, 0.1, 0.8)],
+ endpts=[-1, 0, 1], height=800, width=800)
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix - Intervals')
+ ```
+
+ Example 6: Using the colormap as a Dictionary
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+ import random
+
+ # Create dataframe with random data
+ df = pd.DataFrame(np.random.randn(100, 3),
+ columns=['Column A',
+ 'Column B',
+ 'Column C'])
+
+ # Add new color column to dataframe
+ new_column = []
+ strange_colors = ['turquoise', 'limegreen', 'goldenrod']
+
+ for j in range(100):
+ new_column.append(random.choice(strange_colors))
+ df['Colors'] = pd.Series(new_column, index=df.index)
+
+ # Create scatterplot matrix using a dictionary of hex color values
+ # which correspond to actual color names in 'Colors' column
+ fig = create_scatterplotmatrix(
+ df, diag='box', index='Colors',
+ colormap= dict(
+ turquoise = '#00F5FF',
+ limegreen = '#32CD32',
+ goldenrod = '#DAA520'
+ ),
+ colormap_type='cat',
+ height=800, width=800
+ )
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix - colormap dictionary ')
+ ```
+ """
+ # TODO: protected until #282
+ if dataframe is None:
+ dataframe = []
+ if headers is None:
+ headers = []
+ if index_vals is None:
+ index_vals = []
+
+ validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs)
+
+ # Validate colormap
+ if isinstance(colormap, dict):
+ colormap = utils.validate_colors_dict(colormap, 'rgb')
+ else:
+ colormap = utils.validate_colors(colormap, 'rgb')
+
+ if not index:
+ for name in df:
+ headers.append(name)
+ for name in headers:
+ dataframe.append(df[name].values.tolist())
+ # Check for same data-type in df columns
+ utils.validate_dataframe(dataframe)
+ figure = scatterplot(dataframe, headers, diag, size, height, width,
+ title, **kwargs)
+ return figure
+ else:
+ # Validate index selection
+ if index not in df:
+ raise exceptions.PlotlyError("Make sure you set the index "
+ "input variable to one of the "
+ "column names of your "
+ "dataframe.")
+ index_vals = df[index].values.tolist()
+ for name in df:
+ if name != index:
+ headers.append(name)
+ for name in headers:
+ dataframe.append(df[name].values.tolist())
+
+ # check for same data-type in each df column
+ utils.validate_dataframe(dataframe)
+ utils.validate_index(index_vals)
+
+ # check if all colormap keys are in the index
+ # if colormap is a dictionary
+ if isinstance(colormap, dict):
+ for key in colormap:
+ if not all(index in colormap for index in index_vals):
+ raise exceptions.PlotlyError("If colormap is a "
+ "dictionary, all the "
+ "names in the index "
+ "must be keys.")
+ figure = scatterplot_dict(
+ dataframe, headers, diag, size, height, width, title,
+ index, index_vals, endpts, colormap, colormap_type,
+ **kwargs
+ )
+ return figure
+
+ else:
+ figure = scatterplot_theme(
+ dataframe, headers, diag, size, height, width, title,
+ index, index_vals, endpts, colormap, colormap_type,
+ **kwargs
+ )
+ return figure
diff --git a/plotly/figure_factory/_streamline.py b/plotly/figure_factory/_streamline.py
new file mode 100644
index 00000000000..ddc14778c43
--- /dev/null
+++ b/plotly/figure_factory/_streamline.py
@@ -0,0 +1,411 @@
+from __future__ import absolute_import
+
+import math
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+np = optional_imports.get_module('numpy')
+
+
+def validate_streamline(x, y):
+ """
+ Streamline-specific validations
+
+ Specifically, this checks that x and y are both evenly spaced,
+ and that the package numpy is available.
+
+ See FigureFactory.create_streamline() for params
+
+ :raises: (ImportError) If numpy is not available.
+ :raises: (PlotlyError) If x is not evenly spaced.
+ :raises: (PlotlyError) If y is not evenly spaced.
+ """
+ if np is False:
+ raise ImportError("FigureFactory.create_streamline requires numpy")
+ for index in range(len(x) - 1):
+ if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001:
+ raise exceptions.PlotlyError("x must be a 1 dimensional, "
+ "evenly spaced array")
+ for index in range(len(y) - 1):
+ if ((y[index + 1] - y[index]) -
+ (y[1] - y[0])) > .0001:
+ raise exceptions.PlotlyError("y must be a 1 dimensional, "
+ "evenly spaced array")
+
+
+def create_streamline(x, y, u, v, density=1, angle=math.pi / 9,
+ arrow_scale=.09, **kwargs):
+ """
+ Returns data for a streamline plot.
+
+ :param (list|ndarray) x: 1 dimensional, evenly spaced list or array
+ :param (list|ndarray) y: 1 dimensional, evenly spaced list or array
+ :param (ndarray) u: 2 dimensional array
+ :param (ndarray) v: 2 dimensional array
+ :param (float|int) density: controls the density of streamlines in
+ plot. This is multiplied by 30 to scale similiarly to other
+ available streamline functions such as matplotlib.
+ Default = 1
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
+ Default = .09
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter
+ for more information on valid kwargs call
+ help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of streamline figure.
+
+ Example 1: Plot simple streamline and increase arrow size
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_streamline
+
+ import numpy as np
+ import math
+
+ # Add data
+ x = np.linspace(-3, 3, 100)
+ y = np.linspace(-3, 3, 100)
+ Y, X = np.meshgrid(x, y)
+ u = -1 - X**2 + Y
+ v = 1 + X - Y**2
+ u = u.T # Transpose
+ v = v.T # Transpose
+
+ # Create streamline
+ fig = create_streamline(x, y, u, v, arrow_scale=.1)
+
+ # Plot
+ py.plot(fig, filename='streamline')
+ ```
+
+ Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_streamline
+
+ import numpy as np
+ import math
+
+ # Add data
+ N = 50
+ x_start, x_end = -2.0, 2.0
+ y_start, y_end = -1.0, 1.0
+ x = np.linspace(x_start, x_end, N)
+ y = np.linspace(y_start, y_end, N)
+ X, Y = np.meshgrid(x, y)
+ ss = 5.0
+ x_s, y_s = -1.0, 0.0
+
+ # Compute the velocity field on the mesh grid
+ u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2)
+ v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2)
+
+ # Create streamline
+ fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline')
+
+ # Add source point
+ point = Scatter(x=[x_s], y=[y_s], mode='markers',
+ marker=Marker(size=14), name='source point')
+
+ # Plot
+ fig['data'].append(point)
+ py.plot(fig, filename='streamline')
+ ```
+ """
+ utils.validate_equal_length(x, y)
+ utils.validate_equal_length(u, v)
+ validate_streamline(x, y)
+ utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale)
+
+ streamline_x, streamline_y = _Streamline(x, y, u, v,
+ density, angle,
+ arrow_scale).sum_streamlines()
+ arrow_x, arrow_y = _Streamline(x, y, u, v,
+ density, angle,
+ arrow_scale).get_streamline_arrows()
+
+ streamline = graph_objs.Scatter(x=streamline_x + arrow_x,
+ y=streamline_y + arrow_y,
+ mode='lines', **kwargs)
+
+ data = [streamline]
+ layout = graph_objs.Layout(hovermode='closest')
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Streamline(object):
+ """
+ Refer to FigureFactory.create_streamline() for docstring
+ """
+ def __init__(self, x, y, u, v,
+ density, angle,
+ arrow_scale, **kwargs):
+ self.x = np.array(x)
+ self.y = np.array(y)
+ self.u = np.array(u)
+ self.v = np.array(v)
+ self.angle = angle
+ self.arrow_scale = arrow_scale
+ self.density = int(30 * density) # Scale similarly to other functions
+ self.delta_x = self.x[1] - self.x[0]
+ self.delta_y = self.y[1] - self.y[0]
+ self.val_x = self.x
+ self.val_y = self.y
+
+ # Set up spacing
+ self.blank = np.zeros((self.density, self.density))
+ self.spacing_x = len(self.x) / float(self.density - 1)
+ self.spacing_y = len(self.y) / float(self.density - 1)
+ self.trajectories = []
+
+ # Rescale speed onto axes-coordinates
+ self.u = self.u / (self.x[-1] - self.x[0])
+ self.v = self.v / (self.y[-1] - self.y[0])
+ self.speed = np.sqrt(self.u ** 2 + self.v ** 2)
+
+ # Rescale u and v for integrations.
+ self.u *= len(self.x)
+ self.v *= len(self.y)
+ self.st_x = []
+ self.st_y = []
+ self.get_streamlines()
+ streamline_x, streamline_y = self.sum_streamlines()
+ arrows_x, arrows_y = self.get_streamline_arrows()
+
+ def blank_pos(self, xi, yi):
+ """
+ Set up positions for trajectories to be used with rk4 function.
+ """
+ return (int((xi / self.spacing_x) + 0.5),
+ int((yi / self.spacing_y) + 0.5))
+
+ def value_at(self, a, xi, yi):
+ """
+ Set up for RK4 function, based on Bokeh's streamline code
+ """
+ if isinstance(xi, np.ndarray):
+ self.x = xi.astype(np.int)
+ self.y = yi.astype(np.int)
+ else:
+ self.val_x = np.int(xi)
+ self.val_y = np.int(yi)
+ a00 = a[self.val_y, self.val_x]
+ a01 = a[self.val_y, self.val_x + 1]
+ a10 = a[self.val_y + 1, self.val_x]
+ a11 = a[self.val_y + 1, self.val_x + 1]
+ xt = xi - self.val_x
+ yt = yi - self.val_y
+ a0 = a00 * (1 - xt) + a01 * xt
+ a1 = a10 * (1 - xt) + a11 * xt
+ return a0 * (1 - yt) + a1 * yt
+
+ def rk4_integrate(self, x0, y0):
+ """
+ RK4 forward and back trajectories from the initial conditions.
+
+ Adapted from Bokeh's streamline -uses Runge-Kutta method to fill
+ x and y trajectories then checks length of traj (s in units of axes)
+ """
+ def f(xi, yi):
+ dt_ds = 1. / self.value_at(self.speed, xi, yi)
+ ui = self.value_at(self.u, xi, yi)
+ vi = self.value_at(self.v, xi, yi)
+ return ui * dt_ds, vi * dt_ds
+
+ def g(xi, yi):
+ dt_ds = 1. / self.value_at(self.speed, xi, yi)
+ ui = self.value_at(self.u, xi, yi)
+ vi = self.value_at(self.v, xi, yi)
+ return -ui * dt_ds, -vi * dt_ds
+
+ check = lambda xi, yi: (0 <= xi < len(self.x) - 1 and
+ 0 <= yi < len(self.y) - 1)
+ xb_changes = []
+ yb_changes = []
+
+ def rk4(x0, y0, f):
+ ds = 0.01
+ stotal = 0
+ xi = x0
+ yi = y0
+ xb, yb = self.blank_pos(xi, yi)
+ xf_traj = []
+ yf_traj = []
+ while check(xi, yi):
+ xf_traj.append(xi)
+ yf_traj.append(yi)
+ try:
+ k1x, k1y = f(xi, yi)
+ k2x, k2y = f(xi + .5 * ds * k1x, yi + .5 * ds * k1y)
+ k3x, k3y = f(xi + .5 * ds * k2x, yi + .5 * ds * k2y)
+ k4x, k4y = f(xi + ds * k3x, yi + ds * k3y)
+ except IndexError:
+ break
+ xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.
+ yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.
+ if not check(xi, yi):
+ break
+ stotal += ds
+ new_xb, new_yb = self.blank_pos(xi, yi)
+ if new_xb != xb or new_yb != yb:
+ if self.blank[new_yb, new_xb] == 0:
+ self.blank[new_yb, new_xb] = 1
+ xb_changes.append(new_xb)
+ yb_changes.append(new_yb)
+ xb = new_xb
+ yb = new_yb
+ else:
+ break
+ if stotal > 2:
+ break
+ return stotal, xf_traj, yf_traj
+
+ sf, xf_traj, yf_traj = rk4(x0, y0, f)
+ sb, xb_traj, yb_traj = rk4(x0, y0, g)
+ stotal = sf + sb
+ x_traj = xb_traj[::-1] + xf_traj[1:]
+ y_traj = yb_traj[::-1] + yf_traj[1:]
+
+ if len(x_traj) < 1:
+ return None
+ if stotal > .2:
+ initxb, inityb = self.blank_pos(x0, y0)
+ self.blank[inityb, initxb] = 1
+ return x_traj, y_traj
+ else:
+ for xb, yb in zip(xb_changes, yb_changes):
+ self.blank[yb, xb] = 0
+ return None
+
+ def traj(self, xb, yb):
+ """
+ Integrate trajectories
+
+ :param (int) xb: results of passing xi through self.blank_pos
+ :param (int) xy: results of passing yi through self.blank_pos
+
+ Calculate each trajectory based on rk4 integrate method.
+ """
+
+ if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density:
+ return
+ if self.blank[yb, xb] == 0:
+ t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y)
+ if t is not None:
+ self.trajectories.append(t)
+
+ def get_streamlines(self):
+ """
+ Get streamlines by building trajectory set.
+ """
+ for indent in range(self.density // 2):
+ for xi in range(self.density - 2 * indent):
+ self.traj(xi + indent, indent)
+ self.traj(xi + indent, self.density - 1 - indent)
+ self.traj(indent, xi + indent)
+ self.traj(self.density - 1 - indent, xi + indent)
+
+ self.st_x = [np.array(t[0]) * self.delta_x + self.x[0] for t in
+ self.trajectories]
+ self.st_y = [np.array(t[1]) * self.delta_y + self.y[0] for t in
+ self.trajectories]
+
+ for index in range(len(self.st_x)):
+ self.st_x[index] = self.st_x[index].tolist()
+ self.st_x[index].append(np.nan)
+
+ for index in range(len(self.st_y)):
+ self.st_y[index] = self.st_y[index].tolist()
+ self.st_y[index].append(np.nan)
+
+ def get_streamline_arrows(self):
+ """
+ Makes an arrow for each streamline.
+
+ Gets angle of streamline at 1/3 mark and creates arrow coordinates
+ based off of user defined angle and arrow_scale.
+
+ :param (array) st_x: x-values for all streamlines
+ :param (array) st_y: y-values for all streamlines
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
+ Default = .09
+ :rtype (list, list) arrows_x: x-values to create arrowhead and
+ arrows_y: y-values to create arrowhead
+ """
+ arrow_end_x = np.empty((len(self.st_x)))
+ arrow_end_y = np.empty((len(self.st_y)))
+ arrow_start_x = np.empty((len(self.st_x)))
+ arrow_start_y = np.empty((len(self.st_y)))
+ for index in range(len(self.st_x)):
+ arrow_end_x[index] = (self.st_x[index]
+ [int(len(self.st_x[index]) / 3)])
+ arrow_start_x[index] = (self.st_x[index]
+ [(int(len(self.st_x[index]) / 3)) - 1])
+ arrow_end_y[index] = (self.st_y[index]
+ [int(len(self.st_y[index]) / 3)])
+ arrow_start_y[index] = (self.st_y[index]
+ [(int(len(self.st_y[index]) / 3)) - 1])
+
+ dif_x = arrow_end_x - arrow_start_x
+ dif_y = arrow_end_y - arrow_start_y
+
+ streamline_ang = np.arctan(dif_y / dif_x)
+
+ ang1 = streamline_ang + (self.angle)
+ ang2 = streamline_ang - (self.angle)
+
+ seg1_x = np.cos(ang1) * self.arrow_scale
+ seg1_y = np.sin(ang1) * self.arrow_scale
+ seg2_x = np.cos(ang2) * self.arrow_scale
+ seg2_y = np.sin(ang2) * self.arrow_scale
+
+ point1_x = np.empty((len(dif_x)))
+ point1_y = np.empty((len(dif_y)))
+ point2_x = np.empty((len(dif_x)))
+ point2_y = np.empty((len(dif_y)))
+
+ for index in range(len(dif_x)):
+ if dif_x[index] >= 0:
+ point1_x[index] = arrow_end_x[index] - seg1_x[index]
+ point1_y[index] = arrow_end_y[index] - seg1_y[index]
+ point2_x[index] = arrow_end_x[index] - seg2_x[index]
+ point2_y[index] = arrow_end_y[index] - seg2_y[index]
+ else:
+ point1_x[index] = arrow_end_x[index] + seg1_x[index]
+ point1_y[index] = arrow_end_y[index] + seg1_y[index]
+ point2_x[index] = arrow_end_x[index] + seg2_x[index]
+ point2_y[index] = arrow_end_y[index] + seg2_y[index]
+
+ space = np.empty((len(point1_x)))
+ space[:] = np.nan
+
+ # Combine arrays into matrix
+ arrows_x = np.matrix([point1_x, arrow_end_x, point2_x, space])
+ arrows_x = np.array(arrows_x)
+ arrows_x = arrows_x.flatten('F')
+ arrows_x = arrows_x.tolist()
+
+ # Combine arrays into matrix
+ arrows_y = np.matrix([point1_y, arrow_end_y, point2_y, space])
+ arrows_y = np.array(arrows_y)
+ arrows_y = arrows_y.flatten('F')
+ arrows_y = arrows_y.tolist()
+
+ return arrows_x, arrows_y
+
+ def sum_streamlines(self):
+ """
+ Makes all streamlines readable as a single trace.
+
+ :rtype (list, list): streamline_x: all x values for each streamline
+ combined into single list and streamline_y: all y values for each
+ streamline combined into single list
+ """
+ streamline_x = sum(self.st_x, [])
+ streamline_y = sum(self.st_y, [])
+ return streamline_x, streamline_y
diff --git a/plotly/figure_factory/_table.py b/plotly/figure_factory/_table.py
new file mode 100644
index 00000000000..001dc9ea0ec
--- /dev/null
+++ b/plotly/figure_factory/_table.py
@@ -0,0 +1,232 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.graph_objs import graph_objs
+
+pd = optional_imports.get_module('pandas')
+
+
+def validate_table(table_text, font_colors):
+ """
+ Table-specific validations
+
+ Check that font_colors is supplied correctly (1, 3, or len(text)
+ colors).
+
+ :raises: (PlotlyError) If font_colors is supplied incorretly.
+
+ See FigureFactory.create_table() for params
+ """
+ font_colors_len_options = [1, 3, len(table_text)]
+ if len(font_colors) not in font_colors_len_options:
+ raise exceptions.PlotlyError("Oops, font_colors should be a list "
+ "of length 1, 3 or len(text)")
+
+
+def create_table(table_text, colorscale=None, font_colors=None,
+ index=False, index_title='', annotation_offset=.45,
+ height_constant=30, hoverinfo='none', **kwargs):
+ """
+ BETA function that creates data tables
+
+ :param (pandas.Dataframe | list[list]) text: data for table.
+ :param (str|list[list]) colorscale: Colorscale for table where the
+ color at value 0 is the header color, .5 is the first table color
+ and 1 is the second table color. (Set .5 and 1 to avoid the striped
+ table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'],
+ [1, '#ffffff']]
+ :param (list) font_colors: Color for fonts in table. Can be a single
+ color, three colors, or a color for each row in the table.
+ Default=['#000000'] (black text for the entire table)
+ :param (int) height_constant: Constant multiplied by # of rows to
+ create table height. Default=30.
+ :param (bool) index: Create (header-colored) index column index from
+ Pandas dataframe or list[0] for each list in text. Default=False.
+ :param (string) index_title: Title for index column. Default=''.
+ :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
+ These kwargs describe other attributes about the annotated Heatmap
+ trace such as the colorscale. For more information on valid kwargs
+ call help(plotly.graph_objs.Heatmap)
+
+ Example 1: Simple Plotly Table
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_table
+
+ text = [['Country', 'Year', 'Population'],
+ ['US', 2000, 282200000],
+ ['Canada', 2000, 27790000],
+ ['US', 2010, 309000000],
+ ['Canada', 2010, 34000000]]
+
+ table = create_table(text)
+ py.iplot(table)
+ ```
+
+ Example 2: Table with Custom Coloring
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_table
+
+ text = [['Country', 'Year', 'Population'],
+ ['US', 2000, 282200000],
+ ['Canada', 2000, 27790000],
+ ['US', 2010, 309000000],
+ ['Canada', 2010, 34000000]]
+
+ table = create_table(text,
+ colorscale=[[0, '#000000'],
+ [.5, '#80beff'],
+ [1, '#cce5ff']],
+ font_colors=['#ffffff', '#000000',
+ '#000000'])
+ py.iplot(table)
+ ```
+ Example 3: Simple Plotly Table with Pandas
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_table
+
+ import pandas as pd
+
+ df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t')
+ df_p = df[0:25]
+
+ table_simple = create_table(df_p)
+ py.iplot(table_simple)
+ ```
+ """
+
+ # Avoiding mutables in the call signature
+ colorscale = \
+ colorscale if colorscale is not None else [[0, '#00083e'],
+ [.5, '#ededee'],
+ [1, '#ffffff']]
+ font_colors = font_colors if font_colors is not None else ['#ffffff',
+ '#000000',
+ '#000000']
+
+ validate_table(table_text, font_colors)
+ table_matrix = _Table(table_text, colorscale, font_colors, index,
+ index_title, annotation_offset,
+ **kwargs).get_table_matrix()
+ annotations = _Table(table_text, colorscale, font_colors, index,
+ index_title, annotation_offset,
+ **kwargs).make_table_annotations()
+
+ trace = dict(type='heatmap', z=table_matrix, opacity=.75,
+ colorscale=colorscale, showscale=False,
+ hoverinfo=hoverinfo, **kwargs)
+
+ data = [trace]
+ layout = dict(annotations=annotations,
+ height=len(table_matrix) * height_constant + 50,
+ margin=dict(t=0, b=0, r=0, l=0),
+ yaxis=dict(autorange='reversed', zeroline=False,
+ gridwidth=2, ticks='', dtick=1, tick0=.5,
+ showticklabels=False),
+ xaxis=dict(zeroline=False, gridwidth=2, ticks='',
+ dtick=1, tick0=-0.5, showticklabels=False))
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Table(object):
+ """
+ Refer to TraceFactory.create_table() for docstring
+ """
+ def __init__(self, table_text, colorscale, font_colors, index,
+ index_title, annotation_offset, **kwargs):
+ if pd and isinstance(table_text, pd.DataFrame):
+ headers = table_text.columns.tolist()
+ table_text_index = table_text.index.tolist()
+ table_text = table_text.values.tolist()
+ table_text.insert(0, headers)
+ if index:
+ table_text_index.insert(0, index_title)
+ for i in range(len(table_text)):
+ table_text[i].insert(0, table_text_index[i])
+ self.table_text = table_text
+ self.colorscale = colorscale
+ self.font_colors = font_colors
+ self.index = index
+ self.annotation_offset = annotation_offset
+ self.x = range(len(table_text[0]))
+ self.y = range(len(table_text))
+
+ def get_table_matrix(self):
+ """
+ Create z matrix to make heatmap with striped table coloring
+
+ :rtype (list[list]) table_matrix: z matrix to make heatmap with striped
+ table coloring.
+ """
+ header = [0] * len(self.table_text[0])
+ odd_row = [.5] * len(self.table_text[0])
+ even_row = [1] * len(self.table_text[0])
+ table_matrix = [None] * len(self.table_text)
+ table_matrix[0] = header
+ for i in range(1, len(self.table_text), 2):
+ table_matrix[i] = odd_row
+ for i in range(2, len(self.table_text), 2):
+ table_matrix[i] = even_row
+ if self.index:
+ for array in table_matrix:
+ array[0] = 0
+ return table_matrix
+
+ def get_table_font_color(self):
+ """
+ Fill font-color array.
+
+ Table text color can vary by row so this extends a single color or
+ creates an array to set a header color and two alternating colors to
+ create the striped table pattern.
+
+ :rtype (list[list]) all_font_colors: list of font colors for each row
+ in table.
+ """
+ if len(self.font_colors) == 1:
+ all_font_colors = self.font_colors*len(self.table_text)
+ elif len(self.font_colors) == 3:
+ all_font_colors = list(range(len(self.table_text)))
+ all_font_colors[0] = self.font_colors[0]
+ for i in range(1, len(self.table_text), 2):
+ all_font_colors[i] = self.font_colors[1]
+ for i in range(2, len(self.table_text), 2):
+ all_font_colors[i] = self.font_colors[2]
+ elif len(self.font_colors) == len(self.table_text):
+ all_font_colors = self.font_colors
+ else:
+ all_font_colors = ['#000000']*len(self.table_text)
+ return all_font_colors
+
+ def make_table_annotations(self):
+ """
+ Generate annotations to fill in table text
+
+ :rtype (list) annotations: list of annotations for each cell of the
+ table.
+ """
+ table_matrix = _Table.get_table_matrix(self)
+ all_font_colors = _Table.get_table_font_color(self)
+ annotations = []
+ for n, row in enumerate(self.table_text):
+ for m, val in enumerate(row):
+ # Bold text in header and index
+ format_text = ('' + str(val) + '' if n == 0 or
+ self.index and m < 1 else str(val))
+ # Match font color of index to font color of header
+ font_color = (self.font_colors[0] if self.index and m == 0
+ else all_font_colors[n])
+ annotations.append(
+ graph_objs.Annotation(
+ text=format_text,
+ x=self.x[m] - self.annotation_offset,
+ y=self.y[n],
+ xref='x1',
+ yref='y1',
+ align="left",
+ xanchor="left",
+ font=dict(color=font_color),
+ showarrow=False))
+ return annotations
diff --git a/plotly/figure_factory/_trisurf.py b/plotly/figure_factory/_trisurf.py
new file mode 100644
index 00000000000..d2b58420471
--- /dev/null
+++ b/plotly/figure_factory/_trisurf.py
@@ -0,0 +1,488 @@
+from __future__ import absolute_import
+
+from plotly import colors, exceptions, optional_imports
+from plotly.graph_objs import graph_objs
+
+np = optional_imports.get_module('numpy')
+
+
+def map_face2color(face, colormap, scale, vmin, vmax):
+ """
+ Normalize facecolor values by vmin/vmax and return rgb-color strings
+
+ This function takes a tuple color along with a colormap and a minimum
+ (vmin) and maximum (vmax) range of possible mean distances for the
+ given parametrized surface. It returns an rgb color based on the mean
+ distance between vmin and vmax
+
+ """
+ if vmin >= vmax:
+ raise exceptions.PlotlyError("Incorrect relation between vmin "
+ "and vmax. The vmin value cannot be "
+ "bigger than or equal to the value "
+ "of vmax.")
+ if len(colormap) == 1:
+ # color each triangle face with the same color in colormap
+ face_color = colormap[0]
+ face_color = colors.convert_to_RGB_255(face_color)
+ face_color = colors.label_rgb(face_color)
+ return face_color
+ if face == vmax:
+ # pick last color in colormap
+ face_color = colormap[-1]
+ face_color = colors.convert_to_RGB_255(face_color)
+ face_color = colors.label_rgb(face_color)
+ return face_color
+ else:
+ if scale is None:
+ # find the normalized distance t of a triangle face between
+ # vmin and vmax where the distance is between 0 and 1
+ t = (face - vmin) / float((vmax - vmin))
+ low_color_index = int(t / (1./(len(colormap) - 1)))
+
+ face_color = colors.find_intermediate_color(
+ colormap[low_color_index],
+ colormap[low_color_index + 1],
+ t * (len(colormap) - 1) - low_color_index
+ )
+
+ face_color = colors.convert_to_RGB_255(face_color)
+ face_color = colors.label_rgb(face_color)
+ else:
+ # find the face color for a non-linearly interpolated scale
+ t = (face - vmin) / float((vmax - vmin))
+
+ low_color_index = 0
+ for k in range(len(scale) - 1):
+ if scale[k] <= t < scale[k+1]:
+ break
+ low_color_index += 1
+
+ low_scale_val = scale[low_color_index]
+ high_scale_val = scale[low_color_index + 1]
+
+ face_color = colors.find_intermediate_color(
+ colormap[low_color_index],
+ colormap[low_color_index + 1],
+ (t - low_scale_val)/(high_scale_val - low_scale_val)
+ )
+
+ face_color = colors.convert_to_RGB_255(face_color)
+ face_color = colors.label_rgb(face_color)
+ return face_color
+
+
+def trisurf(x, y, z, simplices, show_colorbar, edges_color, scale,
+ colormap=None, color_func=None, plot_edges=False, x_edge=None,
+ y_edge=None, z_edge=None, facecolor=None):
+ """
+ Refer to FigureFactory.create_trisurf() for docstring
+ """
+ # numpy import check
+ if not np:
+ raise ImportError("FigureFactory._trisurf() requires "
+ "numpy imported.")
+ points3D = np.vstack((x, y, z)).T
+ simplices = np.atleast_2d(simplices)
+
+ # vertices of the surface triangles
+ tri_vertices = points3D[simplices]
+
+ # Define colors for the triangle faces
+ if color_func is None:
+ # mean values of z-coordinates of triangle vertices
+ mean_dists = tri_vertices[:, :, 2].mean(-1)
+ elif isinstance(color_func, (list, np.ndarray)):
+ # Pre-computed list / array of values to map onto color
+ if len(color_func) != len(simplices):
+ raise ValueError("If color_func is a list/array, it must "
+ "be the same length as simplices.")
+
+ # convert all colors in color_func to rgb
+ for index in range(len(color_func)):
+ if isinstance(color_func[index], str):
+ if '#' in color_func[index]:
+ foo = colors.hex_to_rgb(color_func[index])
+ color_func[index] = colors.label_rgb(foo)
+
+ if isinstance(color_func[index], tuple):
+ foo = colors.convert_to_RGB_255(color_func[index])
+ color_func[index] = colors.label_rgb(foo)
+
+ mean_dists = np.asarray(color_func)
+ else:
+ # apply user inputted function to calculate
+ # custom coloring for triangle vertices
+ mean_dists = []
+ for triangle in tri_vertices:
+ dists = []
+ for vertex in triangle:
+ dist = color_func(vertex[0], vertex[1], vertex[2])
+ dists.append(dist)
+ mean_dists.append(np.mean(dists))
+ mean_dists = np.asarray(mean_dists)
+
+ # Check if facecolors are already strings and can be skipped
+ if isinstance(mean_dists[0], str):
+ facecolor = mean_dists
+ else:
+ min_mean_dists = np.min(mean_dists)
+ max_mean_dists = np.max(mean_dists)
+
+ if facecolor is None:
+ facecolor = []
+ for index in range(len(mean_dists)):
+ color = map_face2color(mean_dists[index], colormap, scale,
+ min_mean_dists, max_mean_dists)
+ facecolor.append(color)
+
+ # Make sure facecolor is a list so output is consistent across Pythons
+ facecolor = np.asarray(facecolor)
+ ii, jj, kk = simplices.T
+
+ triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor,
+ i=ii, j=jj, k=kk, name='')
+
+ mean_dists_are_numbers = not isinstance(mean_dists[0], str)
+
+ if mean_dists_are_numbers and show_colorbar is True:
+ # make a colorscale from the colors
+ colorscale = colors.make_colorscale(colormap, scale)
+ colorscale = colors.convert_colorscale_to_rgb(colorscale)
+
+ colorbar = graph_objs.Scatter3d(
+ x=x[:1],
+ y=y[:1],
+ z=z[:1],
+ mode='markers',
+ marker=dict(
+ size=0.1,
+ color=[min_mean_dists, max_mean_dists],
+ colorscale=colorscale,
+ showscale=True),
+ hoverinfo='None',
+ showlegend=False
+ )
+
+ # the triangle sides are not plotted
+ if plot_edges is False:
+ if mean_dists_are_numbers and show_colorbar is True:
+ return graph_objs.Data([triangles, colorbar])
+ else:
+ return graph_objs.Data([triangles])
+
+ # define the lists x_edge, y_edge and z_edge, of x, y, resp z
+ # coordinates of edge end points for each triangle
+ # None separates data corresponding to two consecutive triangles
+ is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
+ if any(is_none):
+ if not all(is_none):
+ raise ValueError("If any (x_edge, y_edge, z_edge) is None, "
+ "all must be None")
+ else:
+ x_edge = []
+ y_edge = []
+ z_edge = []
+
+ # Pull indices we care about, then add a None column to separate tris
+ ixs_triangles = [0, 1, 2, 0]
+ pull_edges = tri_vertices[:, ixs_triangles, :]
+ x_edge_pull = np.hstack([pull_edges[:, :, 0],
+ np.tile(None, [pull_edges.shape[0], 1])])
+ y_edge_pull = np.hstack([pull_edges[:, :, 1],
+ np.tile(None, [pull_edges.shape[0], 1])])
+ z_edge_pull = np.hstack([pull_edges[:, :, 2],
+ np.tile(None, [pull_edges.shape[0], 1])])
+
+ # Now unravel the edges into a 1-d vector for plotting
+ x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
+ y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
+ z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])
+
+ if not (len(x_edge) == len(y_edge) == len(z_edge)):
+ raise exceptions.PlotlyError("The lengths of x_edge, y_edge and "
+ "z_edge are not the same.")
+
+ # define the lines for plotting
+ lines = graph_objs.Scatter3d(
+ x=x_edge, y=y_edge, z=z_edge, mode='lines',
+ line=graph_objs.Line(
+ color=edges_color,
+ width=1.5
+ ),
+ showlegend=False
+ )
+
+ if mean_dists_are_numbers and show_colorbar is True:
+ return graph_objs.Data([triangles, lines, colorbar])
+ else:
+ return graph_objs.Data([triangles, lines])
+
+
+def create_trisurf(x, y, z, simplices, colormap=None, show_colorbar=True,
+ scale=None, color_func=None, title='Trisurf Plot',
+ plot_edges=True, showbackground=True,
+ backgroundcolor='rgb(230, 230, 230)',
+ gridcolor='rgb(255, 255, 255)',
+ zerolinecolor='rgb(255, 255, 255)',
+ edges_color='rgb(50, 50, 50)',
+ height=800, width=800,
+ aspectratio=None):
+ """
+ Returns figure for a triangulated surface plot
+
+ :param (array) x: data values of x in a 1D array
+ :param (array) y: data values of y in a 1D array
+ :param (array) z: data values of z in a 1D array
+ :param (array) simplices: an array of shape (ntri, 3) where ntri is
+ the number of triangles in the triangularization. Each row of the
+ array contains the indicies of the verticies of each triangle
+ :param (str|tuple|list) colormap: either a plotly scale name, an rgb
+ or hex color, a color tuple or a list of colors. An rgb color is
+ of the form 'rgb(x, y, z)' where x, y, z belong to the interval
+ [0, 255] and a color tuple is a tuple of the form (a, b, c) where
+ a, b and c belong to [0, 1]. If colormap is a list, it must
+ contain the valid color types aforementioned as its members
+ :param (bool) show_colorbar: determines if colorbar is visible
+ :param (list|array) scale: sets the scale values to be used if a non-
+ linearly interpolated colormap is desired. If left as None, a
+ linear interpolation between the colors will be excecuted
+ :param (function|list) color_func: The parameter that determines the
+ coloring of the surface. Takes either a function with 3 arguments
+ x, y, z or a list/array of color values the same length as
+ simplices. If None, coloring will only depend on the z axis
+ :param (str) title: title of the plot
+ :param (bool) plot_edges: determines if the triangles on the trisurf
+ are visible
+ :param (bool) showbackground: makes background in plot visible
+ :param (str) backgroundcolor: color of background. Takes a string of
+ the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
+ :param (str) gridcolor: color of the gridlines besides the axes. Takes
+ a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255
+ inclusive
+ :param (str) zerolinecolor: color of the axes. Takes a string of the
+ form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
+ :param (str) edges_color: color of the edges, if plot_edges is True
+ :param (int|float) height: the height of the plot (in pixels)
+ :param (int|float) width: the width of the plot (in pixels)
+ :param (dict) aspectratio: a dictionary of the aspect ratio values for
+ the x, y and z axes. 'x', 'y' and 'z' take (int|float) values
+
+ Example 1: Sphere
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u = np.linspace(0, 2*np.pi, 20)
+ v = np.linspace(0, np.pi, 20)
+ u,v = np.meshgrid(u,v)
+ u = u.flatten()
+ v = v.flatten()
+
+ x = np.sin(v)*np.cos(u)
+ y = np.sin(v)*np.sin(u)
+ z = np.cos(v)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+ # Create a figure
+ fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow",
+ simplices=simplices)
+ # Plot the data
+ py.iplot(fig1, filename='trisurf-plot-sphere')
+ ```
+
+ Example 2: Torus
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u = np.linspace(0, 2*np.pi, 20)
+ v = np.linspace(0, 2*np.pi, 20)
+ u,v = np.meshgrid(u,v)
+ u = u.flatten()
+ v = v.flatten()
+
+ x = (3 + (np.cos(v)))*np.cos(u)
+ y = (3 + (np.cos(v)))*np.sin(u)
+ z = np.sin(v)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+ # Create a figure
+ fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis",
+ simplices=simplices)
+ # Plot the data
+ py.iplot(fig1, filename='trisurf-plot-torus')
+ ```
+
+ Example 3: Mobius Band
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u = np.linspace(0, 2*np.pi, 24)
+ v = np.linspace(-1, 1, 8)
+ u,v = np.meshgrid(u,v)
+ u = u.flatten()
+ v = v.flatten()
+
+ tp = 1 + 0.5*v*np.cos(u/2.)
+ x = tp*np.cos(u)
+ y = tp*np.sin(u)
+ z = 0.5*v*np.sin(u/2.)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+ # Create a figure
+ fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)],
+ simplices=simplices)
+ # Plot the data
+ py.iplot(fig1, filename='trisurf-plot-mobius-band')
+ ```
+
+ Example 4: Using a Custom Colormap Function with Light Cone
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u=np.linspace(-np.pi, np.pi, 30)
+ v=np.linspace(-np.pi, np.pi, 30)
+ u,v=np.meshgrid(u,v)
+ u=u.flatten()
+ v=v.flatten()
+
+ x = u
+ y = u*np.cos(v)
+ z = u*np.sin(v)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+ # Define distance function
+ def dist_origin(x, y, z):
+ return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2)
+
+ # Create a figure
+ fig1 = create_trisurf(x=x, y=y, z=z,
+ colormap=['#FFFFFF', '#E4FFFE',
+ '#A4F6F9', '#FF99FE',
+ '#BA52ED'],
+ scale=[0, 0.6, 0.71, 0.89, 1],
+ simplices=simplices,
+ color_func=dist_origin)
+ # Plot the data
+ py.iplot(fig1, filename='trisurf-plot-custom-coloring')
+ ```
+
+ Example 5: Enter color_func as a list of colors
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+ import random
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u=np.linspace(-np.pi, np.pi, 30)
+ v=np.linspace(-np.pi, np.pi, 30)
+ u,v=np.meshgrid(u,v)
+ u=u.flatten()
+ v=v.flatten()
+
+ x = u
+ y = u*np.cos(v)
+ z = u*np.sin(v)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+
+ colors = []
+ color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd']
+
+ for index in range(len(simplices)):
+ colors.append(random.choice(color_choices))
+
+ fig = create_trisurf(
+ x, y, z, simplices,
+ color_func=colors,
+ show_colorbar=True,
+ edges_color='rgb(2, 85, 180)',
+ title=' Modern Art'
+ )
+
+ py.iplot(fig, filename="trisurf-plot-modern-art")
+ ```
+ """
+ if aspectratio is None:
+ aspectratio = {'x': 1, 'y': 1, 'z': 1}
+
+ # Validate colormap
+ colors.validate_colors(colormap)
+ colormap, scale = colors.convert_colors_to_same_type(
+ colormap, colortype='tuple',
+ return_default_colors=True, scale=scale
+ )
+
+ data1 = trisurf(x, y, z, simplices, show_colorbar=show_colorbar,
+ color_func=color_func, colormap=colormap, scale=scale,
+ edges_color=edges_color, plot_edges=plot_edges)
+
+ axis = dict(
+ showbackground=showbackground,
+ backgroundcolor=backgroundcolor,
+ gridcolor=gridcolor,
+ zerolinecolor=zerolinecolor,
+ )
+ layout = graph_objs.Layout(
+ title=title,
+ width=width,
+ height=height,
+ scene=graph_objs.Scene(
+ xaxis=graph_objs.XAxis(axis),
+ yaxis=graph_objs.YAxis(axis),
+ zaxis=graph_objs.ZAxis(axis),
+ aspectratio=dict(
+ x=aspectratio['x'],
+ y=aspectratio['y'],
+ z=aspectratio['z']),
+ )
+ )
+
+ return graph_objs.Figure(data=data1, layout=layout)
diff --git a/plotly/figure_factory/_violin.py b/plotly/figure_factory/_violin.py
new file mode 100644
index 00000000000..d501105482e
--- /dev/null
+++ b/plotly/figure_factory/_violin.py
@@ -0,0 +1,627 @@
+from __future__ import absolute_import
+
+from numbers import Number
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+from plotly.tools import make_subplots
+
+pd = optional_imports.get_module('pandas')
+np = optional_imports.get_module('numpy')
+scipy_stats = optional_imports.get_module('scipy.stats')
+
+
+def calc_stats(data):
+ """
+ Calculate statistics for use in violin plot.
+ """
+ x = np.asarray(data, np.float)
+ vals_min = np.min(x)
+ vals_max = np.max(x)
+ q2 = np.percentile(x, 50, interpolation='linear')
+ q1 = np.percentile(x, 25, interpolation='lower')
+ q3 = np.percentile(x, 75, interpolation='higher')
+ iqr = q3 - q1
+ whisker_dist = 1.5 * iqr
+
+ # in order to prevent drawing whiskers outside the interval
+ # of data one defines the whisker positions as:
+ d1 = np.min(x[x >= (q1 - whisker_dist)])
+ d2 = np.max(x[x <= (q3 + whisker_dist)])
+ return {
+ 'min': vals_min,
+ 'max': vals_max,
+ 'q1': q1,
+ 'q2': q2,
+ 'q3': q3,
+ 'd1': d1,
+ 'd2': d2
+ }
+
+
+def make_half_violin(x, y, fillcolor='#1f77b4', linecolor='rgb(0, 0, 0)'):
+ """
+ Produces a sideways probability distribution fig violin plot.
+ """
+ text = ['(pdf(y), y)=(' + '{:0.2f}'.format(x[i]) +
+ ', ' + '{:0.2f}'.format(y[i]) + ')'
+ for i in range(len(x))]
+
+ return graph_objs.Scatter(
+ x=x,
+ y=y,
+ mode='lines',
+ name='',
+ text=text,
+ fill='tonextx',
+ fillcolor=fillcolor,
+ line=graph_objs.Line(width=0.5, color=linecolor, shape='spline'),
+ hoverinfo='text',
+ opacity=0.5
+ )
+
+
+def make_violin_rugplot(vals, pdf_max, distance, color='#1f77b4'):
+ """
+ Returns a rugplot fig for a violin plot.
+ """
+ return graph_objs.Scatter(
+ y=vals,
+ x=[-pdf_max-distance]*len(vals),
+ marker=graph_objs.Marker(
+ color=color,
+ symbol='line-ew-open'
+ ),
+ mode='markers',
+ name='',
+ showlegend=False,
+ hoverinfo='y'
+ )
+
+
+def make_non_outlier_interval(d1, d2):
+ """
+ Returns the scatterplot fig of most of a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0, 0],
+ y=[d1, d2],
+ name='',
+ mode='lines',
+ line=graph_objs.Line(width=1.5,
+ color='rgb(0,0,0)')
+ )
+
+
+def make_quartiles(q1, q3):
+ """
+ Makes the upper and lower quartiles for a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0, 0],
+ y=[q1, q3],
+ text=['lower-quartile: ' + '{:0.2f}'.format(q1),
+ 'upper-quartile: ' + '{:0.2f}'.format(q3)],
+ mode='lines',
+ line=graph_objs.Line(
+ width=4,
+ color='rgb(0,0,0)'
+ ),
+ hoverinfo='text'
+ )
+
+
+def make_median(q2):
+ """
+ Formats the 'median' hovertext for a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0],
+ y=[q2],
+ text=['median: ' + '{:0.2f}'.format(q2)],
+ mode='markers',
+ marker=dict(symbol='square',
+ color='rgb(255,255,255)'),
+ hoverinfo='text'
+ )
+
+
+def make_XAxis(xaxis_title, xaxis_range):
+ """
+ Makes the x-axis for a violin plot.
+ """
+ xaxis = graph_objs.XAxis(title=xaxis_title,
+ range=xaxis_range,
+ showgrid=False,
+ zeroline=False,
+ showline=False,
+ mirror=False,
+ ticks='',
+ showticklabels=False)
+ return xaxis
+
+
+def make_YAxis(yaxis_title):
+ """
+ Makes the y-axis for a violin plot.
+ """
+ yaxis = graph_objs.YAxis(title=yaxis_title,
+ showticklabels=True,
+ autorange=True,
+ ticklen=4,
+ showline=True,
+ zeroline=False,
+ showgrid=False,
+ mirror=False)
+ return yaxis
+
+
+def violinplot(vals, fillcolor='#1f77b4', rugplot=True):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+ """
+ vals = np.asarray(vals, np.float)
+ # summary statistics
+ vals_min = calc_stats(vals)['min']
+ vals_max = calc_stats(vals)['max']
+ q1 = calc_stats(vals)['q1']
+ q2 = calc_stats(vals)['q2']
+ q3 = calc_stats(vals)['q3']
+ d1 = calc_stats(vals)['d1']
+ d2 = calc_stats(vals)['d2']
+
+ # kernel density estimation of pdf
+ pdf = scipy_stats.gaussian_kde(vals)
+ # grid over the data interval
+ xx = np.linspace(vals_min, vals_max, 100)
+ # evaluate the pdf at the grid xx
+ yy = pdf(xx)
+ max_pdf = np.max(yy)
+ # distance from the violin plot to rugplot
+ distance = (2.0 * max_pdf)/10 if rugplot else 0
+ # range for x values in the plot
+ plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
+ plot_data = [make_half_violin(-yy, xx, fillcolor=fillcolor),
+ make_half_violin(yy, xx, fillcolor=fillcolor),
+ make_non_outlier_interval(d1, d2),
+ make_quartiles(q1, q3),
+ make_median(q2)]
+ if rugplot:
+ plot_data.append(make_violin_rugplot(vals, max_pdf, distance=distance,
+ color=fillcolor))
+ return plot_data, plot_xrange
+
+
+def violin_no_colorscale(data, data_header, group_header, colors,
+ use_colorscale, group_stats,
+ height, width, title):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot without colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+ group_name.sort()
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(rows=1, cols=L,
+ shared_yaxes=True,
+ horizontal_spacing=0.025,
+ print_grid=False)
+ color_index = 0
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], np.float)
+ if color_index >= len(colors):
+ color_index = 0
+ plot_data, plot_xrange = violinplot(vals,
+ fillcolor=colors[color_index])
+ layout = graph_objs.Layout()
+
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+ color_index += 1
+
+ # add violin plot labels
+ fig['layout'].update(
+ {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+
+ # set the sharey axis style
+ fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
+ fig['layout'].update(
+ title=title,
+ showlegend=False,
+ hovermode='closest',
+ autosize=False,
+ height=height,
+ width=width
+ )
+
+ return fig
+
+
+def violin_colorscale(data, data_header, group_header, colors, use_colorscale,
+ group_stats, height, width, title):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot with colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+ group_name.sort()
+
+ # make sure all group names are keys in group_stats
+ for group in group_name:
+ if group not in group_stats:
+ raise exceptions.PlotlyError("All values/groups in the index "
+ "column must be represented "
+ "as a key in group_stats.")
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(rows=1, cols=L,
+ shared_yaxes=True,
+ horizontal_spacing=0.025,
+ print_grid=False)
+
+ # prepare low and high color for colorscale
+ lowcolor = utils.color_parser(colors[0], utils.unlabel_rgb)
+ highcolor = utils.color_parser(colors[1], utils.unlabel_rgb)
+
+ # find min and max values in group_stats
+ group_stats_values = []
+ for key in group_stats:
+ group_stats_values.append(group_stats[key])
+
+ max_value = max(group_stats_values)
+ min_value = min(group_stats_values)
+
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], np.float)
+
+ # find intermediate color from colorscale
+ intermed = (group_stats[gr] - min_value) / (max_value - min_value)
+ intermed_color = utils.find_intermediate_color(
+ lowcolor, highcolor, intermed
+ )
+
+ plot_data, plot_xrange = violinplot(
+ vals,
+ fillcolor='rgb{}'.format(intermed_color)
+ )
+ layout = graph_objs.Layout()
+
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+ fig['layout'].update(
+ {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+ # add colorbar to plot
+ trace_dummy = graph_objs.Scatter(
+ x=[0],
+ y=[0],
+ mode='markers',
+ marker=dict(
+ size=2,
+ cmin=min_value,
+ cmax=max_value,
+ colorscale=[[0, colors[0]],
+ [1, colors[1]]],
+ showscale=True),
+ showlegend=False,
+ )
+ fig.append_trace(trace_dummy, 1, L)
+
+ # set the sharey axis style
+ fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
+ fig['layout'].update(
+ title=title,
+ showlegend=False,
+ hovermode='closest',
+ autosize=False,
+ height=height,
+ width=width
+ )
+
+ return fig
+
+
+def violin_dict(data, data_header, group_header, colors, use_colorscale,
+ group_stats, height, width, title):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot without colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+ group_name.sort()
+
+ # check if all group names appear in colors dict
+ for group in group_name:
+ if group not in colors:
+ raise exceptions.PlotlyError("If colors is a dictionary, all "
+ "the group names must appear as "
+ "keys in colors.")
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(rows=1, cols=L,
+ shared_yaxes=True,
+ horizontal_spacing=0.025,
+ print_grid=False)
+
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], np.float)
+ plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr])
+ layout = graph_objs.Layout()
+
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+
+ # add violin plot labels
+ fig['layout'].update(
+ {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+
+ # set the sharey axis style
+ fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
+ fig['layout'].update(
+ title=title,
+ showlegend=False,
+ hovermode='closest',
+ autosize=False,
+ height=height,
+ width=width
+ )
+
+ return fig
+
+
+def create_violin(data, data_header=None, group_header=None, colors=None,
+ use_colorscale=False, group_stats=None, height=450,
+ width=600, title='Violin and Rug Plot'):
+ """
+ Returns figure for a violin plot
+
+ :param (list|array) data: accepts either a list of numerical values,
+ a list of dictionaries all with identical keys and at least one
+ column of numeric values, or a pandas dataframe with at least one
+ column of numbers
+ :param (str) data_header: the header of the data column to be used
+ from an inputted pandas dataframe. Not applicable if 'data' is
+ a list of numeric values
+ :param (str) group_header: applicable if grouping data by a variable.
+ 'group_header' must be set to the name of the grouping variable.
+ :param (str|tuple|list|dict) colors: either a plotly scale name,
+ an rgb or hex color, a color tuple, a list of colors or a
+ dictionary. An rgb color is of the form 'rgb(x, y, z)' where
+ x, y and z belong to the interval [0, 255] and a color tuple is a
+ tuple of the form (a, b, c) where a, b and c belong to [0, 1].
+ If colors is a list, it must contain valid color types as its
+ members.
+ :param (bool) use_colorscale: Only applicable if grouping by another
+ variable. Will implement a colorscale based on the first 2 colors
+ of param colors. This means colors must be a list with at least 2
+ colors in it (Plotly colorscales are accepted since they map to a
+ list of two rgb colors)
+ :param (dict) group_stats: a dictioanry where each key is a unique
+ value from the group_header column in data. Each value must be a
+ number and will be used to color the violin plots if a colorscale
+ is being used
+ :param (float) height: the height of the violin plot
+ :param (float) width: the width of the violin plot
+ :param (str) title: the title of the violin plot
+
+ Example 1: Single Violin Plot
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_violin
+ from plotly.graph_objs import graph_objs
+
+ import numpy as np
+ from scipy import stats
+
+ # create list of random values
+ data_list = np.random.randn(100)
+ data_list.tolist()
+
+ # create violin fig
+ fig = create_violin(data_list, colors='#604d9e')
+
+ # plot
+ py.iplot(fig, filename='Violin Plot')
+ ```
+
+ Example 2: Multiple Violin Plots with Qualitative Coloring
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_violin
+ from plotly.graph_objs import graph_objs
+
+ import numpy as np
+ import pandas as pd
+ from scipy import stats
+
+ # create dataframe
+ np.random.seed(619517)
+ Nr=250
+ y = np.random.randn(Nr)
+ gr = np.random.choice(list("ABCDE"), Nr)
+ norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
+
+ for i, letter in enumerate("ABCDE"):
+ y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
+ df = pd.DataFrame(dict(Score=y, Group=gr))
+
+ # create violin fig
+ fig = create_violin(df, data_header='Score', group_header='Group',
+ height=600, width=1000)
+
+ # plot
+ py.iplot(fig, filename='Violin Plot with Coloring')
+ ```
+
+ Example 3: Violin Plots with Colorscale
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_violin
+ from plotly.graph_objs import graph_objs
+
+ import numpy as np
+ import pandas as pd
+ from scipy import stats
+
+ # create dataframe
+ np.random.seed(619517)
+ Nr=250
+ y = np.random.randn(Nr)
+ gr = np.random.choice(list("ABCDE"), Nr)
+ norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
+
+ for i, letter in enumerate("ABCDE"):
+ y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
+ df = pd.DataFrame(dict(Score=y, Group=gr))
+
+ # define header params
+ data_header = 'Score'
+ group_header = 'Group'
+
+ # make groupby object with pandas
+ group_stats = {}
+ groupby_data = df.groupby([group_header])
+
+ for group in "ABCDE":
+ data_from_group = groupby_data.get_group(group)[data_header]
+ # take a stat of the grouped data
+ stat = np.median(data_from_group)
+ # add to dictionary
+ group_stats[group] = stat
+
+ # create violin fig
+ fig = create_violin(df, data_header='Score', group_header='Group',
+ height=600, width=1000, use_colorscale=True,
+ group_stats=group_stats)
+
+ # plot
+ py.iplot(fig, filename='Violin Plot with Colorscale')
+ ```
+ """
+
+ # Validate colors
+ if isinstance(colors, dict):
+ valid_colors = utils.validate_colors_dict(colors, 'rgb')
+ else:
+ valid_colors = utils.validate_colors(colors, 'rgb')
+
+ # validate data and choose plot type
+ if group_header is None:
+ if isinstance(data, list):
+ if len(data) <= 0:
+ raise exceptions.PlotlyError("If data is a list, it must be "
+ "nonempty and contain either "
+ "numbers or dictionaries.")
+
+ if not all(isinstance(element, Number) for element in data):
+ raise exceptions.PlotlyError("If data is a list, it must "
+ "contain only numbers.")
+
+ if pd and isinstance(data, pd.core.frame.DataFrame):
+ if data_header is None:
+ raise exceptions.PlotlyError("data_header must be the "
+ "column name with the "
+ "desired numeric data for "
+ "the violin plot.")
+
+ data = data[data_header].values.tolist()
+
+ # call the plotting functions
+ plot_data, plot_xrange = violinplot(data, fillcolor=valid_colors[0])
+
+ layout = graph_objs.Layout(
+ title=title,
+ autosize=False,
+ font=graph_objs.Font(size=11),
+ height=height,
+ showlegend=False,
+ width=width,
+ xaxis=make_XAxis('', plot_xrange),
+ yaxis=make_YAxis(''),
+ hovermode='closest'
+ )
+ layout['yaxis'].update(dict(showline=False,
+ showticklabels=False,
+ ticks=''))
+
+ fig = graph_objs.Figure(data=graph_objs.Data(plot_data),
+ layout=layout)
+
+ return fig
+
+ else:
+ if not isinstance(data, pd.core.frame.DataFrame):
+ raise exceptions.PlotlyError("Error. You must use a pandas "
+ "DataFrame if you are using a "
+ "group header.")
+
+ if data_header is None:
+ raise exceptions.PlotlyError("data_header must be the column "
+ "name with the desired numeric "
+ "data for the violin plot.")
+
+ if use_colorscale is False:
+ if isinstance(valid_colors, dict):
+ # validate colors dict choice below
+ fig = violin_dict(
+ data, data_header, group_header, valid_colors,
+ use_colorscale, group_stats, height, width, title
+ )
+ return fig
+ else:
+ fig = violin_no_colorscale(
+ data, data_header, group_header, valid_colors,
+ use_colorscale, group_stats, height, width, title
+ )
+ return fig
+ else:
+ if isinstance(valid_colors, dict):
+ raise exceptions.PlotlyError("The colors param cannot be "
+ "a dictionary if you are "
+ "using a colorscale.")
+
+ if len(valid_colors) < 2:
+ raise exceptions.PlotlyError("colors must be a list with "
+ "at least 2 colors. A "
+ "Plotly scale is allowed.")
+
+ if not isinstance(group_stats, dict):
+ raise exceptions.PlotlyError("Your group_stats param "
+ "must be a dictionary.")
+
+ fig = violin_colorscale(
+ data, data_header, group_header, valid_colors,
+ use_colorscale, group_stats, height, width, title
+ )
+ return fig
diff --git a/plotly/figure_factory/utils.py b/plotly/figure_factory/utils.py
new file mode 100644
index 00000000000..114cd2be5d5
--- /dev/null
+++ b/plotly/figure_factory/utils.py
@@ -0,0 +1,376 @@
+from __future__ import absolute_import
+
+import decimal
+
+from plotly import exceptions
+
+DEFAULT_PLOTLY_COLORS = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)',
+ 'rgb(44, 160, 44)', 'rgb(214, 39, 40)',
+ 'rgb(148, 103, 189)', 'rgb(140, 86, 75)',
+ 'rgb(227, 119, 194)', 'rgb(127, 127, 127)',
+ 'rgb(188, 189, 34)', 'rgb(23, 190, 207)']
+
+PLOTLY_SCALES = {'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'],
+ 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'],
+ 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'],
+ 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'],
+ 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'],
+ 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'],
+ 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'],
+ 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'],
+ 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'],
+ 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'],
+ 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'],
+ 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'],
+ 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'],
+ 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'],
+ 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'],
+ 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'],
+ 'Viridis': ['rgb(68,1,84)', 'rgb(253,231,37)']}
+
+
+def validate_index(index_vals):
+ """
+ Validates if a list contains all numbers or all strings
+
+ :raises: (PlotlyError) If there are any two items in the list whose
+ types differ
+ """
+ from numbers import Number
+ if isinstance(index_vals[0], Number):
+ if not all(isinstance(item, Number) for item in index_vals):
+ raise exceptions.PlotlyError("Error in indexing column. "
+ "Make sure all entries of each "
+ "column are all numbers or "
+ "all strings.")
+
+ elif isinstance(index_vals[0], str):
+ if not all(isinstance(item, str) for item in index_vals):
+ raise exceptions.PlotlyError("Error in indexing column. "
+ "Make sure all entries of each "
+ "column are all numbers or "
+ "all strings.")
+
+
+def validate_dataframe(array):
+ """
+ Validates all strings or numbers in each dataframe column
+
+ :raises: (PlotlyError) If there are any two items in any list whose
+ types differ
+ """
+ from numbers import Number
+ for vector in array:
+ if isinstance(vector[0], Number):
+ if not all(isinstance(item, Number) for item in vector):
+ raise exceptions.PlotlyError("Error in dataframe. "
+ "Make sure all entries of "
+ "each column are either "
+ "numbers or strings.")
+ elif isinstance(vector[0], str):
+ if not all(isinstance(item, str) for item in vector):
+ raise exceptions.PlotlyError("Error in dataframe. "
+ "Make sure all entries of "
+ "each column are either "
+ "numbers or strings.")
+
+
+def validate_equal_length(*args):
+ """
+ Validates that data lists or ndarrays are the same length.
+
+ :raises: (PlotlyError) If any data lists are not the same length.
+ """
+ length = len(args[0])
+ if any(len(lst) != length for lst in args):
+ raise exceptions.PlotlyError("Oops! Your data lists or ndarrays "
+ "should be the same length.")
+
+
+def validate_positive_scalars(**kwargs):
+ """
+ Validates that all values given in key/val pairs are positive.
+
+ Accepts kwargs to improve Exception messages.
+
+ :raises: (PlotlyError) If any value is < 0 or raises.
+ """
+ for key, val in kwargs.items():
+ try:
+ if val <= 0:
+ raise ValueError('{} must be > 0, got {}'.format(key, val))
+ except TypeError:
+ raise exceptions.PlotlyError('{} must be a number, got {}'
+ .format(key, val))
+
+
+def flatten(array):
+ """
+ Uses list comprehension to flatten array
+
+ :param (array): An iterable to flatten
+ :raises (PlotlyError): If iterable is not nested.
+ :rtype (list): The flattened list.
+ """
+ try:
+ return [item for sublist in array for item in sublist]
+ except TypeError:
+ raise exceptions.PlotlyError("Your data array could not be "
+ "flattened! Make sure your data is "
+ "entered as lists or ndarrays!")
+
+
+def find_intermediate_color(lowcolor, highcolor, intermed):
+ """
+ Returns the color at a given distance between two colors
+
+ This function takes two color tuples, where each element is between 0
+ and 1, along with a value 0 < intermed < 1 and returns a color that is
+ intermed-percent from lowcolor to highcolor
+
+ """
+ diff_0 = float(highcolor[0] - lowcolor[0])
+ diff_1 = float(highcolor[1] - lowcolor[1])
+ diff_2 = float(highcolor[2] - lowcolor[2])
+
+ return (lowcolor[0] + intermed * diff_0,
+ lowcolor[1] + intermed * diff_1,
+ lowcolor[2] + intermed * diff_2)
+
+
+def n_colors(lowcolor, highcolor, n_colors):
+ """
+ Splits a low and high color into a list of n_colors colors in it
+
+ Accepts two color tuples and returns a list of n_colors colors
+ which form the intermediate colors between lowcolor and highcolor
+ from linearly interpolating through RGB space
+
+ """
+ diff_0 = float(highcolor[0] - lowcolor[0])
+ incr_0 = diff_0/(n_colors - 1)
+ diff_1 = float(highcolor[1] - lowcolor[1])
+ incr_1 = diff_1/(n_colors - 1)
+ diff_2 = float(highcolor[2] - lowcolor[2])
+ incr_2 = diff_2/(n_colors - 1)
+ color_tuples = []
+
+ for index in range(n_colors):
+ new_tuple = (lowcolor[0] + (index * incr_0),
+ lowcolor[1] + (index * incr_1),
+ lowcolor[2] + (index * incr_2))
+ color_tuples.append(new_tuple)
+
+ return color_tuples
+
+
+def label_rgb(colors):
+ """
+ Takes tuple (a, b, c) and returns an rgb color 'rgb(a, b, c)'
+ """
+ return ('rgb(%s, %s, %s)' % (colors[0], colors[1], colors[2]))
+
+
+def unlabel_rgb(colors):
+ """
+ Takes rgb color(s) 'rgb(a, b, c)' and returns tuple(s) (a, b, c)
+
+ This function takes either an 'rgb(a, b, c)' color or a list of
+ such colors and returns the color tuples in tuple(s) (a, b, c)
+
+ """
+ str_vals = ''
+ for index in range(len(colors)):
+ try:
+ float(colors[index])
+ str_vals = str_vals + colors[index]
+ except ValueError:
+ if colors[index] == ',' or colors[index] == '.':
+ str_vals = str_vals + colors[index]
+
+ str_vals = str_vals + ','
+ numbers = []
+ str_num = ''
+ for char in str_vals:
+ if char != ',':
+ str_num = str_num + char
+ else:
+ numbers.append(float(str_num))
+ str_num = ''
+ return (numbers[0], numbers[1], numbers[2])
+
+
+def unconvert_from_RGB_255(colors):
+ """
+ Return a tuple where each element gets divided by 255
+
+ Takes a (list of) color tuple(s) where each element is between 0 and
+ 255. Returns the same tuples where each tuple element is normalized to
+ a value between 0 and 1
+
+ """
+ return (colors[0]/(255.0),
+ colors[1]/(255.0),
+ colors[2]/(255.0))
+
+
+def convert_to_RGB_255(colors):
+ """
+ Multiplies each element of a triplet by 255
+
+ Each coordinate of the color tuple is rounded to the nearest float and
+ then is turned into an integer. If a number is of the form x.5, then
+ if x is odd, the number rounds up to (x+1). Otherwise, it rounds down
+ to just x. This is the way rounding works in Python 3 and in current
+ statistical analysis to avoid rounding bias
+ """
+ rgb_components = []
+
+ for component in colors:
+ rounded_num = decimal.Decimal(str(component*255.0)).quantize(
+ decimal.Decimal('1'), rounding=decimal.ROUND_HALF_EVEN
+ )
+ # convert rounded number to an integer from 'Decimal' form
+ rounded_num = int(rounded_num)
+ rgb_components.append(rounded_num)
+
+ return (rgb_components[0], rgb_components[1], rgb_components[2])
+
+
+def hex_to_rgb(value):
+ """
+ Calculates rgb values from a hex color code.
+
+ :param (string) value: Hex color string
+
+ :rtype (tuple) (r_value, g_value, b_value): tuple of rgb values
+ """
+ value = value.lstrip('#')
+ hex_total_length = len(value)
+ rgb_section_length = hex_total_length // 3
+ return tuple(int(value[i:i + rgb_section_length], 16)
+ for i in range(0, hex_total_length, rgb_section_length))
+
+
+def color_parser(colors, function):
+ """
+ Takes color(s) and a function and applies the function on the color(s)
+
+ In particular, this function identifies whether the given color object
+ is an iterable or not and applies the given color-parsing function to
+ the color or iterable of colors. If given an iterable, it will only be
+ able to work with it if all items in the iterable are of the same type
+ - rgb string, hex string or tuple
+
+ """
+ from numbers import Number
+ if isinstance(colors, str):
+ return function(colors)
+
+ if isinstance(colors, tuple) and isinstance(colors[0], Number):
+ return function(colors)
+
+ if hasattr(colors, '__iter__'):
+ if isinstance(colors, tuple):
+ new_color_tuple = tuple(function(item) for item in colors)
+ return new_color_tuple
+
+ else:
+ new_color_list = [function(item) for item in colors]
+ return new_color_list
+
+
+def validate_colors(colors, colortype='tuple'):
+ """
+ Validates color(s) and returns a list of color(s) of a specified type
+ """
+ from numbers import Number
+ if colors is None:
+ colors = DEFAULT_PLOTLY_COLORS
+
+ if isinstance(colors, str):
+ if colors in PLOTLY_SCALES:
+ colors = PLOTLY_SCALES[colors]
+ elif 'rgb' in colors or '#' in colors:
+ colors = [colors]
+ else:
+ raise exceptions.PlotlyError(
+ "If your colors variable is a string, it must be a "
+ "Plotly scale, an rgb color or a hex color.")
+
+ elif isinstance(colors, tuple):
+ if isinstance(colors[0], Number):
+ colors = [colors]
+ else:
+ colors = list(colors)
+
+ # convert color elements in list to tuple color
+ for j, each_color in enumerate(colors):
+ if 'rgb' in each_color:
+ each_color = color_parser(each_color, unlabel_rgb)
+ for value in each_color:
+ if value > 255.0:
+ raise exceptions.PlotlyError(
+ "Whoops! The elements in your rgb colors "
+ "tuples cannot exceed 255.0."
+ )
+ each_color = color_parser(each_color, unconvert_from_RGB_255)
+ colors[j] = each_color
+
+ if '#' in each_color:
+ each_color = color_parser(each_color, hex_to_rgb)
+ each_color = color_parser(each_color, unconvert_from_RGB_255)
+
+ colors[j] = each_color
+
+ if isinstance(each_color, tuple):
+ for value in each_color:
+ if value > 1.0:
+ raise exceptions.PlotlyError(
+ "Whoops! The elements in your colors tuples "
+ "cannot exceed 1.0."
+ )
+ colors[j] = each_color
+
+ if colortype == 'rgb':
+ for j, each_color in enumerate(colors):
+ rgb_color = color_parser(each_color, convert_to_RGB_255)
+ colors[j] = color_parser(rgb_color, label_rgb)
+
+ return colors
+
+
+def validate_colors_dict(colors, colortype='tuple'):
+ """
+ Validates dictioanry of color(s)
+ """
+ # validate each color element in the dictionary
+ for key in colors:
+ if 'rgb' in colors[key]:
+ colors[key] = color_parser(colors[key], unlabel_rgb)
+ for value in colors[key]:
+ if value > 255.0:
+ raise exceptions.PlotlyError(
+ "Whoops! The elements in your rgb colors "
+ "tuples cannot exceed 255.0."
+ )
+ colors[key] = color_parser(colors[key], unconvert_from_RGB_255)
+
+ if '#' in colors[key]:
+ colors[key] = color_parser(colors[key], hex_to_rgb)
+ colors[key] = color_parser(colors[key], unconvert_from_RGB_255)
+
+ if isinstance(colors[key], tuple):
+ for value in colors[key]:
+ if value > 1.0:
+ raise exceptions.PlotlyError(
+ "Whoops! The elements in your colors tuples "
+ "cannot exceed 1.0."
+ )
+
+ if colortype == 'rgb':
+ for key in colors:
+ colors[key] = color_parser(colors[key], convert_to_RGB_255)
+ colors[key] = color_parser(colors[key], label_rgb)
+
+ return colors
diff --git a/plotly/grid_objs/grid_objs.py b/plotly/grid_objs/grid_objs.py
index 782e3dc2fe0..128d3bf90a1 100644
--- a/plotly/grid_objs/grid_objs.py
+++ b/plotly/grid_objs/grid_objs.py
@@ -5,9 +5,10 @@
"""
from __future__ import absolute_import
-import json
from collections import MutableSequence
+from requests.compat import json as _json
+
from plotly import exceptions, utils
__all__ = None
@@ -66,7 +67,7 @@ def __init__(self, data, name):
def __str__(self):
max_chars = 10
- jdata = json.dumps(self.data, cls=utils.PlotlyJSONEncoder)
+ jdata = _json.dumps(self.data, cls=utils.PlotlyJSONEncoder)
if len(jdata) > max_chars:
data_string = jdata[:max_chars] + "...]"
else:
diff --git a/plotly/offline/offline.py b/plotly/offline/offline.py
index 965b5af09e0..2228d52b00d 100644
--- a/plotly/offline/offline.py
+++ b/plotly/offline/offline.py
@@ -5,7 +5,6 @@
"""
from __future__ import absolute_import
-import json
import os
import uuid
import warnings
@@ -13,22 +12,15 @@
import time
import webbrowser
+from requests.compat import json as _json
+
import plotly
-from plotly import tools, utils
+from plotly import optional_imports, tools, utils
from plotly.exceptions import PlotlyError
-try:
- import IPython
- from IPython.display import HTML, display
- _ipython_imported = True
-except ImportError:
- _ipython_imported = False
-
-try:
- import matplotlib
- _matplotlib_imported = True
-except ImportError:
- _matplotlib_imported = False
+ipython = optional_imports.get_module('IPython')
+ipython_display = optional_imports.get_module('IPython.display')
+matplotlib = optional_imports.get_module('matplotlib')
__PLOTLY_OFFLINE_INITIALIZED = False
@@ -111,7 +103,7 @@ def init_notebook_mode(connected=False):
your notebook, resulting in much larger notebook sizes compared to the case
where `connected=True`.
"""
- if not _ipython_imported:
+ if not ipython:
raise ImportError('`iplot` can only run inside an IPython Notebook.')
global __PLOTLY_OFFLINE_INITIALIZED
@@ -148,7 +140,7 @@ def init_notebook_mode(connected=False):
''
'').format(script=get_plotlyjs())
- display(HTML(script_inject))
+ ipython_display.display(ipython_display.HTML(script_inject))
__PLOTLY_OFFLINE_INITIALIZED = True
@@ -183,10 +175,12 @@ def _plot_html(figure_or_data, config, validate, default_width,
height = str(height) + 'px'
plotdivid = uuid.uuid4()
- jdata = json.dumps(figure.get('data', []), cls=utils.PlotlyJSONEncoder)
- jlayout = json.dumps(figure.get('layout', {}), cls=utils.PlotlyJSONEncoder)
+ jdata = _json.dumps(figure.get('data', []), cls=utils.PlotlyJSONEncoder)
+ jlayout = _json.dumps(figure.get('layout', {}),
+ cls=utils.PlotlyJSONEncoder)
if 'frames' in figure_or_data:
- jframes = json.dumps(figure.get('frames', {}), cls=utils.PlotlyJSONEncoder)
+ jframes = _json.dumps(figure.get('frames', {}),
+ cls=utils.PlotlyJSONEncoder)
configkeys = (
'editable',
@@ -211,7 +205,7 @@ def _plot_html(figure_or_data, config, validate, default_width,
)
config_clean = dict((k, config[k]) for k in configkeys if k in config)
- jconfig = json.dumps(config_clean)
+ jconfig = _json.dumps(config_clean)
# TODO: The get_config 'source of truth' should
# really be somewhere other than plotly.plotly
@@ -330,7 +324,7 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly',
'plotly.offline.init_notebook_mode() '
'# run at the start of every ipython notebook',
]))
- if not tools._ipython_imported:
+ if not ipython:
raise ImportError('`iplot` can only run inside an IPython Notebook.')
config = {}
@@ -341,7 +335,7 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly',
figure_or_data, config, validate, '100%', 525, True
)
- display(HTML(plot_html))
+ ipython_display.display(ipython_display.HTML(plot_html))
if image:
if image not in __IMAGE_FORMATS:
@@ -357,7 +351,7 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly',
# allow time for the plot to draw
time.sleep(1)
# inject code to download an image of the plot
- display(HTML(script))
+ ipython_display.display(ipython_display.HTML(script))
def plot(figure_or_data, show_link=True, link_text='Export to plot.ly',
@@ -687,7 +681,7 @@ def enable_mpl_offline(resize=False, strip_style=False,
"""
init_notebook_mode()
- ip = IPython.core.getipython.get_ipython()
+ ip = ipython.core.getipython.get_ipython()
formatter = ip.display_formatter.formatters['text/html']
formatter.for_type(matplotlib.figure.Figure,
lambda fig: iplot_mpl(fig, resize, strip_style, verbose,
diff --git a/plotly/optional_imports.py b/plotly/optional_imports.py
new file mode 100644
index 00000000000..7e9ba805b42
--- /dev/null
+++ b/plotly/optional_imports.py
@@ -0,0 +1,25 @@
+"""
+Stand-alone module to provide information about whether optional deps exist.
+
+"""
+from __future__ import absolute_import
+
+from importlib import import_module
+
+_not_importable = set()
+
+
+def get_module(name):
+ """
+ Return module or None. Absolute import is required.
+
+ :param (str) name: Dot-separated module path. E.g., 'scipy.stats'.
+ :raise: (ImportError) Only when exc_msg is defined.
+ :return: (module|None) If import succeeds, the module will be returned.
+
+ """
+ if name not in _not_importable:
+ try:
+ return import_module(name)
+ except ImportError:
+ _not_importable.add(name)
diff --git a/plotly/plotly/plotly.py b/plotly/plotly/plotly.py
index 756fc190000..a5bf559df71 100644
--- a/plotly/plotly/plotly.py
+++ b/plotly/plotly/plotly.py
@@ -16,25 +16,22 @@
"""
from __future__ import absolute_import
-import base64
import copy
-import json
import os
import warnings
-import requests
import six
import six.moves
+from requests.compat import json as _json
-from requests.auth import HTTPBasicAuth
-
-from plotly import exceptions, tools, utils, version, files
+from plotly import exceptions, files, session, tools, utils
+from plotly.api import v1, v2
from plotly.plotly import chunked_requests
-from plotly.session import (sign_in, update_session_plot_options,
- get_session_plot_options, get_session_credentials,
- get_session_config)
from plotly.grid_objs import Grid, Column
+# This is imported like this for backwards compat. Careful if changing.
+from plotly.config import get_config, get_credentials
+
__all__ = None
DEFAULT_PLOT_OPTIONS = {
@@ -51,34 +48,16 @@
# don't break backwards compatibility
-sign_in = sign_in
-update_plot_options = update_session_plot_options
-
-
-def get_credentials():
- """Returns the credentials that will be sent to plotly."""
- credentials = tools.get_credentials_file()
- session_credentials = get_session_credentials()
- for credentials_key in credentials:
-
- # checking for not false, but truthy value here is the desired behavior
- session_value = session_credentials.get(credentials_key)
- if session_value is False or session_value:
- credentials[credentials_key] = session_value
- return credentials
-
-
-def get_config():
- """Returns either module config or file config."""
- config = tools.get_config_file()
- session_config = get_session_config()
- for config_key in config:
+def sign_in(username, api_key, **kwargs):
+ session.sign_in(username, api_key, **kwargs)
+ try:
+ # The only way this can succeed is if the user can be authenticated
+ # with the given, username, api_key, and plotly_api_domain.
+ v2.users.current()
+ except exceptions.PlotlyRequestError:
+ raise exceptions.PlotlyError('Sign in failed.')
- # checking for not false, but truthy value here is the desired behavior
- session_value = session_config.get(config_key)
- if session_value is False or session_value:
- config[config_key] = session_value
- return config
+update_plot_options = session.update_session_plot_options
def _plot_option_logic(plot_options_from_call_signature):
@@ -93,7 +72,7 @@ def _plot_option_logic(plot_options_from_call_signature):
"""
default_plot_options = copy.deepcopy(DEFAULT_PLOT_OPTIONS)
file_options = tools.get_config_file()
- session_options = get_session_plot_options()
+ session_options = session.get_session_plot_options()
plot_options_from_call_signature = copy.deepcopy(plot_options_from_call_signature)
# Validate options and fill in defaults w world_readable and sharing
@@ -238,15 +217,23 @@ def plot(figure_or_data, validate=True, **plot_options):
pass
plot_options = _plot_option_logic(plot_options)
- res = _send_to_plotly(figure, **plot_options)
- if res['error'] == '':
- if plot_options['auto_open']:
- _open_url(res['url'])
+ fig = tools._replace_newline(figure) # does not mutate figure
+ data = fig.get('data', [])
+ plot_options['layout'] = fig.get('layout', {})
+ response = v1.clientresp(data, **plot_options)
- return res['url']
- else:
- raise exceptions.PlotlyAccountError(res['error'])
+ # Check if the url needs a secret key
+ url = response.json()['url']
+ if plot_options['sharing'] == 'secret':
+ if 'share_key=' not in url:
+ # add_share_key_to_url updates the url to include the share_key
+ url = add_share_key_to_url(url)
+
+ if plot_options['auto_open']:
+ _open_url(url)
+
+ return url
def iplot_mpl(fig, resize=True, strip_style=False, update=None,
@@ -316,6 +303,64 @@ def plot_mpl(fig, resize=True, strip_style=False, update=None, **plot_options):
return plot(fig, **plot_options)
+def _swap_keys(obj, key1, key2):
+ """Swap obj[key1] with obj[key2]"""
+ val1, val2 = None, None
+ try:
+ val2 = obj.pop(key1)
+ except KeyError:
+ pass
+ try:
+ val1 = obj.pop(key2)
+ except KeyError:
+ pass
+ if val2 is not None:
+ obj[key2] = val2
+ if val1 is not None:
+ obj[key1] = val1
+
+
+def _swap_xy_data(data_obj):
+ """Swap x and y data and references"""
+ swaps = [('x', 'y'),
+ ('x0', 'y0'),
+ ('dx', 'dy'),
+ ('xbins', 'ybins'),
+ ('nbinsx', 'nbinsy'),
+ ('autobinx', 'autobiny'),
+ ('error_x', 'error_y')]
+ for swap in swaps:
+ _swap_keys(data_obj, swap[0], swap[1])
+ try:
+ rows = len(data_obj['z'])
+ cols = len(data_obj['z'][0])
+ for row in data_obj['z']:
+ if len(row) != cols:
+ raise TypeError
+
+ # if we can't do transpose, we hit an exception before here
+ z = data_obj.pop('z')
+ data_obj['z'] = [[0 for rrr in range(rows)] for ccc in range(cols)]
+ for iii in range(rows):
+ for jjj in range(cols):
+ data_obj['z'][jjj][iii] = z[iii][jjj]
+ except (KeyError, TypeError, IndexError) as err:
+ warn = False
+ try:
+ if data_obj['z'] is not None:
+ warn = True
+ if len(data_obj['z']) == 0:
+ warn = False
+ except (KeyError, TypeError):
+ pass
+ if warn:
+ warnings.warn(
+ "Data in this file required an 'xy' swap but the 'z' matrix "
+ "in one of the data objects could not be transposed. Here's "
+ "why:\n\n{}".format(repr(err))
+ )
+
+
def get_figure(file_owner_or_url, file_id=None, raw=False):
"""Returns a JSON figure representation for the specified file
@@ -363,15 +408,6 @@ def get_figure(file_owner_or_url, file_id=None, raw=False):
file_id = url.replace(head, "").split('/')[1]
else:
file_owner = file_owner_or_url
- resource = "/apigetfile/{username}/{file_id}".format(username=file_owner,
- file_id=file_id)
- credentials = get_credentials()
- validate_credentials(credentials)
- username, api_key = credentials['username'], credentials['api_key']
- headers = {'plotly-username': username,
- 'plotly-apikey': api_key,
- 'plotly-version': version.__version__,
- 'plotly-platform': 'python'}
try:
int(file_id)
except ValueError:
@@ -386,28 +422,49 @@ def get_figure(file_owner_or_url, file_id=None, raw=False):
"The 'file_id' argument must be a non-negative number."
)
- response = requests.get(plotly_rest_url + resource,
- headers=headers,
- verify=get_config()['plotly_ssl_verification'])
- if response.status_code == 200:
- if six.PY3:
- content = json.loads(response.content.decode('utf-8'))
- else:
- content = json.loads(response.content)
- response_payload = content['payload']
- figure = response_payload['figure']
- utils.decode_unicode(figure)
- if raw:
- return figure
- else:
- return tools.get_valid_graph_obj(figure, obj_type='Figure')
- else:
+ fid = '{}:{}'.format(file_owner, file_id)
+ response = v2.plots.content(fid, inline_data=True)
+ figure = response.json()
+
+ # Fix 'histogramx', 'histogramy', and 'bardir' stuff
+ for index, entry in enumerate(figure['data']):
try:
- content = json.loads(response.content)
- raise exceptions.PlotlyError(content)
- except:
- raise exceptions.PlotlyError(
- "There was an error retrieving this file")
+ # Use xbins to bin data in x, and ybins to bin data in y
+ if all((entry['type'] == 'histogramy', 'xbins' in entry,
+ 'ybins' not in entry)):
+ entry['ybins'] = entry.pop('xbins')
+
+ # Convert bardir to orientation, and put the data into the axes
+ # it's eventually going to be used with
+ if entry['type'] in ['histogramx', 'histogramy']:
+ entry['type'] = 'histogram'
+ if 'bardir' in entry:
+ entry['orientation'] = entry.pop('bardir')
+ if entry['type'] == 'bar':
+ if entry['orientation'] == 'h':
+ _swap_xy_data(entry)
+ if entry['type'] == 'histogram':
+ if ('x' in entry) and ('y' not in entry):
+ if entry['orientation'] == 'h':
+ _swap_xy_data(entry)
+ del entry['orientation']
+ if ('y' in entry) and ('x' not in entry):
+ if entry['orientation'] == 'v':
+ _swap_xy_data(entry)
+ del entry['orientation']
+ figure['data'][index] = entry
+ except KeyError:
+ pass
+
+ # Remove stream dictionary if found in a data trace
+ # (it has private tokens in there we need to hide!)
+ for index, entry in enumerate(figure['data']):
+ if 'stream' in entry:
+ del figure['data'][index]['stream']
+
+ if raw:
+ return figure
+ return tools.get_valid_graph_obj(figure, obj_type='Figure')
@utils.template_doc(**tools.get_config_file())
@@ -592,7 +649,7 @@ def write(self, trace, layout=None, validate=True,
stream_object.update(dict(layout=layout))
# TODO: allow string version of this?
- jdata = json.dumps(stream_object, cls=utils.PlotlyJSONEncoder)
+ jdata = _json.dumps(stream_object, cls=utils.PlotlyJSONEncoder)
jdata += "\n"
try:
@@ -673,10 +730,6 @@ def get(figure_or_data, format='png', width=None, height=None, scale=None):
"Invalid scale parameter. Scale must be a number."
)
- headers = _api_v2.headers()
- headers['plotly_version'] = version.__version__
- headers['content-type'] = 'application/json'
-
payload = {'figure': figure, 'format': format}
if width is not None:
payload['width'] = width
@@ -684,38 +737,18 @@ def get(figure_or_data, format='png', width=None, height=None, scale=None):
payload['height'] = height
if scale is not None:
payload['scale'] = scale
- url = _api_v2.api_url('images/')
-
- res = requests.post(
- url, data=json.dumps(payload, cls=utils.PlotlyJSONEncoder),
- headers=headers, verify=get_config()['plotly_ssl_verification'],
- )
-
- headers = res.headers
- if res.status_code == 200:
- if ('content-type' in headers and
- headers['content-type'] in ['image/png', 'image/jpeg',
- 'application/pdf',
- 'image/svg+xml']):
- return res.content
+ response = v2.images.create(payload)
- elif ('content-type' in headers and
- 'json' in headers['content-type']):
- return_data = json.loads(res.content)
- return return_data['image']
- else:
- try:
- if ('content-type' in headers and
- 'json' in headers['content-type']):
- return_data = json.loads(res.content)
- else:
- return_data = {'error': res.content}
- except:
- raise exceptions.PlotlyError("The response "
- "from plotly could "
- "not be translated.")
- raise exceptions.PlotlyError(return_data['error'])
+ headers = response.headers
+ if ('content-type' in headers and
+ headers['content-type'] in ['image/png', 'image/jpeg',
+ 'application/pdf',
+ 'image/svg+xml']):
+ return response.content
+ elif ('content-type' in headers and
+ 'json' in headers['content-type']):
+ return response.json()['image']
@classmethod
def ishow(cls, figure_or_data, format='png', width=None, height=None,
@@ -829,22 +862,8 @@ def mkdirs(cls, folder_path):
>> mkdirs('new/folder/path')
"""
- # trim trailing slash TODO: necessesary?
- if folder_path[-1] == '/':
- folder_path = folder_path[0:-1]
-
- payload = {
- 'path': folder_path
- }
-
- url = _api_v2.api_url('folders')
-
- res = requests.post(url, data=payload, headers=_api_v2.headers(),
- verify=get_config()['plotly_ssl_verification'])
-
- _api_v2.response_handler(res)
-
- return res.status_code
+ response = v2.folders.create({'path': folder_path})
+ return response.status_code
class grid_ops:
@@ -874,6 +893,15 @@ def _fill_in_response_column_ids(cls, request_columns,
req_col.id = '{0}:{1}'.format(grid_id, resp_col['uid'])
response_columns.remove(resp_col)
+ @staticmethod
+ def ensure_uploaded(fid):
+ if fid:
+ return
+ raise exceptions.PlotlyError(
+ 'This operation requires that the grid has already been uploaded '
+ 'to Plotly. Try `uploading` first.'
+ )
+
@classmethod
def upload(cls, grid, filename,
world_readable=True, auto_open=True, meta=None):
@@ -954,37 +982,32 @@ def upload(cls, grid, filename,
payload = {
'filename': filename,
- 'data': json.dumps(grid_json, cls=utils.PlotlyJSONEncoder),
+ 'data': grid_json,
'world_readable': world_readable
}
if parent_path != '':
payload['parent_path'] = parent_path
- upload_url = _api_v2.api_url('grids')
+ response = v2.grids.create(payload)
- req = requests.post(upload_url, data=payload,
- headers=_api_v2.headers(),
- verify=get_config()['plotly_ssl_verification'])
-
- res = _api_v2.response_handler(req)
-
- response_columns = res['file']['cols']
- grid_id = res['file']['fid']
- grid_url = res['file']['web_url']
+ parsed_content = response.json()
+ cols = parsed_content['file']['cols']
+ fid = parsed_content['file']['fid']
+ web_url = parsed_content['file']['web_url']
# mutate the grid columns with the id's returned from the server
- cls._fill_in_response_column_ids(grid, response_columns, grid_id)
+ cls._fill_in_response_column_ids(grid, cols, fid)
- grid.id = grid_id
+ grid.id = fid
if meta is not None:
meta_ops.upload(meta, grid=grid)
if auto_open:
- _open_url(grid_url)
+ _open_url(web_url)
- return grid_url
+ return web_url
@classmethod
def append_columns(cls, columns, grid=None, grid_url=None):
@@ -1024,7 +1047,9 @@ def append_columns(cls, columns, grid=None, grid_url=None):
```
"""
- grid_id = _api_v2.parse_grid_id_args(grid, grid_url)
+ grid_id = parse_grid_id_args(grid, grid_url)
+
+ grid_ops.ensure_uploaded(grid_id)
# Verify unique column names
column_names = [c.name for c in columns]
@@ -1036,17 +1061,15 @@ def append_columns(cls, columns, grid=None, grid_url=None):
err = exceptions.NON_UNIQUE_COLUMN_MESSAGE.format(duplicate_name)
raise exceptions.InputError(err)
- payload = {
- 'cols': json.dumps(columns, cls=utils.PlotlyJSONEncoder)
+ # This is sorta gross, we need to double-encode this.
+ body = {
+ 'cols': _json.dumps(columns, cls=utils.PlotlyJSONEncoder)
}
+ fid = grid_id
+ response = v2.grids.col_create(fid, body)
+ parsed_content = response.json()
- api_url = (_api_v2.api_url('grids') +
- '/{grid_id}/col'.format(grid_id=grid_id))
- res = requests.post(api_url, data=payload, headers=_api_v2.headers(),
- verify=get_config()['plotly_ssl_verification'])
- res = _api_v2.response_handler(res)
-
- cls._fill_in_response_column_ids(columns, res['cols'], grid_id)
+ cls._fill_in_response_column_ids(columns, parsed_content['cols'], fid)
if grid:
grid.extend(columns)
@@ -1096,7 +1119,9 @@ def append_rows(cls, rows, grid=None, grid_url=None):
```
"""
- grid_id = _api_v2.parse_grid_id_args(grid, grid_url)
+ grid_id = parse_grid_id_args(grid, grid_url)
+
+ grid_ops.ensure_uploaded(grid_id)
if grid:
n_columns = len([column for column in grid])
@@ -1112,15 +1137,8 @@ def append_rows(cls, rows, grid=None, grid_url=None):
n_columns,
'column' if n_columns == 1 else 'columns'))
- payload = {
- 'rows': json.dumps(rows, cls=utils.PlotlyJSONEncoder)
- }
-
- api_url = (_api_v2.api_url('grids') +
- '/{grid_id}/row'.format(grid_id=grid_id))
- res = requests.post(api_url, data=payload, headers=_api_v2.headers(),
- verify=get_config()['plotly_ssl_verification'])
- _api_v2.response_handler(res)
+ fid = grid_id
+ v2.grids.row(fid, {'rows': rows})
if grid:
longest_column_length = max([len(col.data) for col in grid])
@@ -1168,11 +1186,10 @@ def delete(cls, grid=None, grid_url=None):
```
"""
- grid_id = _api_v2.parse_grid_id_args(grid, grid_url)
- api_url = _api_v2.api_url('grids') + '/' + grid_id
- res = requests.delete(api_url, headers=_api_v2.headers(),
- verify=get_config()['plotly_ssl_verification'])
- _api_v2.response_handler(res)
+ fid = parse_grid_id_args(grid, grid_url)
+ grid_ops.ensure_uploaded(fid)
+ v2.grids.trash(fid)
+ v2.grids.permanent_delete(fid)
class meta_ops:
@@ -1230,269 +1247,99 @@ def upload(cls, meta, grid=None, grid_url=None):
```
"""
- grid_id = _api_v2.parse_grid_id_args(grid, grid_url)
+ fid = parse_grid_id_args(grid, grid_url)
+ return v2.grids.update(fid, {'metadata': meta}).json()
- payload = {
- 'metadata': json.dumps(meta, cls=utils.PlotlyJSONEncoder)
- }
-
- api_url = _api_v2.api_url('grids') + '/{grid_id}'.format(grid_id=grid_id)
-
- res = requests.patch(api_url, data=payload, headers=_api_v2.headers(),
- verify=get_config()['plotly_ssl_verification'])
-
- return _api_v2.response_handler(res)
-
-
-class _api_v2:
- """
- Request and response helper class for communicating with Plotly's v2 API
+def parse_grid_id_args(grid, grid_url):
"""
- @classmethod
- def parse_grid_id_args(cls, grid, grid_url):
- """
- Return the grid_id from the non-None input argument.
-
- Raise an error if more than one argument was supplied.
-
- """
- if grid is not None:
- id_from_grid = grid.id
- else:
- id_from_grid = None
- args = [id_from_grid, grid_url]
- arg_names = ('grid', 'grid_url')
-
- supplied_arg_names = [arg_name for arg_name, arg
- in zip(arg_names, args) if arg is not None]
-
- if not supplied_arg_names:
- raise exceptions.InputError(
- "One of the two keyword arguments is required:\n"
- " `grid` or `grid_url`\n\n"
- "grid: a plotly.graph_objs.Grid object that has already\n"
- " been uploaded to Plotly.\n\n"
- "grid_url: the url where the grid can be accessed on\n"
- " Plotly, e.g. 'https://plot.ly/~chris/3043'\n\n"
- )
- elif len(supplied_arg_names) > 1:
- raise exceptions.InputError(
- "Only one of `grid` or `grid_url` is required. \n"
- "You supplied both. \n"
- )
- else:
- supplied_arg_name = supplied_arg_names.pop()
- if supplied_arg_name == 'grid_url':
- path = six.moves.urllib.parse.urlparse(grid_url).path
- file_owner, file_id = path.replace("/~", "").split('/')[0:2]
- return '{0}:{1}'.format(file_owner, file_id)
- else:
- return grid.id
-
- @classmethod
- def response_handler(cls, response):
- try:
- response.raise_for_status()
- except requests.exceptions.HTTPError as requests_exception:
- if (response.status_code == 404 and
- get_config()['plotly_api_domain']
- != tools.get_config_defaults()['plotly_api_domain']):
- raise exceptions.PlotlyError(
- "This endpoint is unavailable at {url}. If you are using "
- "Plotly On-Premise, you may need to upgrade your Plotly "
- "Plotly On-Premise server to request against this endpoint or "
- "this endpoint may not be available yet.\nQuestions? "
- "Visit community.plot.ly, contact your plotly administrator "
- "or upgrade to a Pro account for 1-1 help: https://goo.gl/1YUVu9 "
- .format(url=get_config()['plotly_api_domain'])
- )
- else:
- raise requests_exception
+ Return the grid_id from the non-None input argument.
- if ('content-type' in response.headers and
- 'json' in response.headers['content-type'] and
- len(response.content) > 0):
-
- response_dict = json.loads(response.content.decode('utf8'))
-
- if 'warnings' in response_dict and len(response_dict['warnings']):
- warnings.warn('\n'.join(response_dict['warnings']))
-
- return response_dict
-
- @classmethod
- def api_url(cls, resource):
- return ('{0}/v2/{1}'.format(get_config()['plotly_api_domain'],
- resource))
-
- @classmethod
- def headers(cls):
- credentials = get_credentials()
+ Raise an error if more than one argument was supplied.
- # todo, validate here?
- username, api_key = credentials['username'], credentials['api_key']
- encoded_api_auth = base64.b64encode(six.b('{0}:{1}'.format(
- username, api_key))).decode('utf8')
-
- headers = {
- 'plotly-client-platform': 'python {0}'.format(version.__version__)
- }
-
- if get_config()['plotly_proxy_authorization']:
- proxy_username = credentials['proxy_username']
- proxy_password = credentials['proxy_password']
- encoded_proxy_auth = base64.b64encode(six.b('{0}:{1}'.format(
- proxy_username, proxy_password))).decode('utf8')
- headers['authorization'] = 'Basic ' + encoded_proxy_auth
- headers['plotly-authorization'] = 'Basic ' + encoded_api_auth
- else:
- headers['authorization'] = 'Basic ' + encoded_api_auth
-
- return headers
-
-
-def validate_credentials(credentials):
"""
- Currently only checks for truthy username and api_key
-
- """
- username = credentials.get('username')
- api_key = credentials.get('api_key')
- if not username or not api_key:
- raise exceptions.PlotlyLocalCredentialsError()
+ if grid is not None:
+ id_from_grid = grid.id
+ else:
+ id_from_grid = None
+ args = [id_from_grid, grid_url]
+ arg_names = ('grid', 'grid_url')
+
+ supplied_arg_names = [arg_name for arg_name, arg
+ in zip(arg_names, args) if arg is not None]
+
+ if not supplied_arg_names:
+ raise exceptions.InputError(
+ "One of the two keyword arguments is required:\n"
+ " `grid` or `grid_url`\n\n"
+ "grid: a plotly.graph_objs.Grid object that has already\n"
+ " been uploaded to Plotly.\n\n"
+ "grid_url: the url where the grid can be accessed on\n"
+ " Plotly, e.g. 'https://plot.ly/~chris/3043'\n\n"
+ )
+ elif len(supplied_arg_names) > 1:
+ raise exceptions.InputError(
+ "Only one of `grid` or `grid_url` is required. \n"
+ "You supplied both. \n"
+ )
+ else:
+ supplied_arg_name = supplied_arg_names.pop()
+ if supplied_arg_name == 'grid_url':
+ path = six.moves.urllib.parse.urlparse(grid_url).path
+ file_owner, file_id = path.replace("/~", "").split('/')[0:2]
+ return '{0}:{1}'.format(file_owner, file_id)
+ else:
+ return grid.id
-def add_share_key_to_url(plot_url, attempt=0):
+def add_share_key_to_url(plot_url):
"""
Update plot's url to include the secret key
"""
urlsplit = six.moves.urllib.parse.urlparse(plot_url)
- file_owner = urlsplit.path.split('/')[1].split('~')[1]
- file_id = urlsplit.path.split('/')[2]
+ username = urlsplit.path.split('/')[1].split('~')[1]
+ idlocal = urlsplit.path.split('/')[2]
+ fid = '{}:{}'.format(username, idlocal)
- url = _api_v2.api_url("files/") + file_owner + ":" + file_id
- new_response = requests.patch(url,
- headers=_api_v2.headers(),
- data={"share_key_enabled":
- "True",
- "world_readable":
- "False"})
+ body = {'share_key_enabled': True, 'world_readable': False}
+ response = v2.files.update(fid, body)
- _api_v2.response_handler(new_response)
-
- # decode bytes for python 3.3: https://bugs.python.org/issue10976
- str_content = new_response.content.decode('utf-8')
-
- new_response_data = json.loads(str_content)
-
- plot_url += '?share_key=' + new_response_data['share_key']
-
- # sometimes a share key is added, but access is still denied
- # check for access, and retry a couple of times if this is the case
- # https://github.com/plotly/streambed/issues/4089
- embed_url = plot_url.split('?')[0] + '.embed' + plot_url.split('?')[1]
- access_res = requests.get(embed_url)
- if access_res.status_code == 404:
- attempt += 1
- if attempt == 5:
- return plot_url
- plot_url = add_share_key_to_url(plot_url.split('?')[0], attempt)
-
- return plot_url
+ return plot_url + '?share_key=' + response.json()['share_key']
def _send_to_plotly(figure, **plot_options):
fig = tools._replace_newline(figure) # does not mutate figure
- data = json.dumps(fig['data'] if 'data' in fig else [],
- cls=utils.PlotlyJSONEncoder)
- credentials = get_credentials()
- validate_credentials(credentials)
- username = credentials['username']
- api_key = credentials['api_key']
- kwargs = json.dumps(dict(filename=plot_options['filename'],
- fileopt=plot_options['fileopt'],
- world_readable=plot_options['world_readable'],
- sharing=plot_options['sharing'],
- layout=fig['layout'] if 'layout' in fig else {}),
- cls=utils.PlotlyJSONEncoder)
-
- # TODO: It'd be cool to expose the platform for RaspPi and others
- payload = dict(platform='python',
- version=version.__version__,
- args=data,
- un=username,
- key=api_key,
- origin='plot',
- kwargs=kwargs)
-
- url = get_config()['plotly_domain'] + "/clientresp"
-
- r = requests.post(url, data=payload,
- verify=get_config()['plotly_ssl_verification'])
- r.raise_for_status()
- r = json.loads(r.text)
-
- if 'error' in r and r['error'] != '':
- raise exceptions.PlotlyError(r['error'])
-
- # Check if the url needs a secret key
- if (plot_options['sharing'] == 'secret' and
- 'share_key=' not in r['url']):
+ data = fig.get('data', [])
+ response = v1.clientresp(data, **plot_options)
- # add_share_key_to_url updates the url to include the share_key
- r['url'] = add_share_key_to_url(r['url'])
+ parsed_content = response.json()
- if 'error' in r and r['error'] != '':
- print(r['error'])
- if 'warning' in r and r['warning'] != '':
- warnings.warn(r['warning'])
- if 'message' in r and r['message'] != '':
- print(r['message'])
+ # Check if the url needs a secret key
+ if plot_options['sharing'] == 'secret':
+ url = parsed_content['url']
+ if 'share_key=' not in url:
+ # add_share_key_to_url updates the url to include the share_key
+ parsed_content['url'] = add_share_key_to_url(url)
- return r
+ return parsed_content
def get_grid(grid_url, raw=False):
"""
Returns the specified grid as a Grid instance or in JSON/dict form.
+ :param (str) grid_url: The web_url which locates a Plotly grid.
:param (bool) raw: if False, will output a Grid instance of the JSON grid
being retrieved. If True, raw JSON will be returned.
"""
- credentials = get_credentials()
- validate_credentials(credentials)
- username, api_key = credentials['username'], credentials['api_key']
- headers = {'plotly-username': username,
- 'plotly-apikey': api_key,
- 'plotly-version': version.__version__,
- 'plotly-platform': 'python'}
- upload_url = _api_v2.api_url('grids')
-
- # extract path in grid url
- url_path = six.moves.urllib.parse.urlparse(grid_url)[2][2:]
- if url_path[-1] == '/':
- url_path = url_path[0: -1]
- url_path = url_path.replace('/', ':')
-
- meta_get_url = upload_url + '/' + url_path
- get_url = meta_get_url + '/content'
-
- r = requests.get(get_url, headers=headers)
- json_res = json.loads(r.text)
-
- # make request to grab the grid id (fid)
- r_meta = requests.get(meta_get_url, headers=headers)
- r_meta.raise_for_status()
-
- json_res_meta = json.loads(r_meta.text)
- retrieved_grid_id = json_res_meta['fid']
-
- if raw is False:
- return Grid(json_res, retrieved_grid_id)
- else:
- return json_res
+ fid = parse_grid_id_args(None, grid_url)
+ response = v2.grids.content(fid)
+ parsed_content = response.json()
+
+ if raw:
+ return parsed_content
+ return Grid(parsed_content, fid)
def create_animations(figure, filename=None, sharing='public', auto_open=True):
@@ -1661,14 +1508,7 @@ def create_animations(figure, filename=None, sharing='public', auto_open=True):
py.create_animations(figure, 'growing_circles')
```
"""
- credentials = get_credentials()
- validate_credentials(credentials)
- username, api_key = credentials['username'], credentials['api_key']
- auth = HTTPBasicAuth(str(username), str(api_key))
- headers = {'Plotly-Client-Platform': 'python',
- 'content-type': 'application/json'}
-
- json = {
+ body = {
'figure': figure,
'world_readable': True
}
@@ -1682,48 +1522,30 @@ def create_animations(figure, filename=None, sharing='public', auto_open=True):
"automatic folder creation. This means a filename of the form "
"'name1/name2' will just create the plot with that name only."
)
- json['filename'] = filename
+ body['filename'] = filename
# set sharing
if sharing == 'public':
- json['world_readable'] = True
+ body['world_readable'] = True
elif sharing == 'private':
- json['world_readable'] = False
+ body['world_readable'] = False
elif sharing == 'secret':
- json['world_readable'] = False
- json['share_key_enabled'] = True
+ body['world_readable'] = False
+ body['share_key_enabled'] = True
else:
raise exceptions.PlotlyError(
"Whoops, sharing can only be set to either 'public', 'private', "
"or 'secret'."
)
- api_url = _api_v2.api_url('plots')
- r = requests.post(api_url, auth=auth, headers=headers, json=json)
-
- try:
- parsed_response = r.json()
- except:
- parsed_response = r.content
-
- # raise error message
- if not r.ok:
- message = ''
- if isinstance(parsed_response, dict):
- errors = parsed_response.get('errors')
- if errors and errors[-1].get('message'):
- message = errors[-1]['message']
- if message:
- raise exceptions.PlotlyError(message)
- else:
- # shucks, we're stuck with a generic error...
- r.raise_for_status()
+ response = v2.plots.create(body)
+ parsed_content = response.json()
if sharing == 'secret':
- web_url = (parsed_response['file']['web_url'][:-1] +
- '?share_key=' + parsed_response['file']['share_key'])
+ web_url = (parsed_content['file']['web_url'][:-1] +
+ '?share_key=' + parsed_content['file']['share_key'])
else:
- web_url = parsed_response['file']['web_url']
+ web_url = parsed_content['file']['web_url']
if auto_open:
_open_url(web_url)
@@ -1738,7 +1560,6 @@ def icreate_animations(figure, filename=None, sharing='public', auto_open=False)
This function is based off `plotly.plotly.iplot`. See `plotly.plotly.
create_animations` Doc String for param descriptions.
"""
- # Still needs doing: create a wrapper for iplot and icreate_animations
url = create_animations(figure, filename, sharing, auto_open)
if isinstance(figure, dict):
diff --git a/plotly/session.py b/plotly/session.py
index e93d9a85996..2e72d45bff3 100644
--- a/plotly/session.py
+++ b/plotly/session.py
@@ -22,6 +22,8 @@
CREDENTIALS_KEYS = {
'username': six.string_types,
'api_key': six.string_types,
+ 'proxy_username': six.string_types,
+ 'proxy_password': six.string_types,
'stream_ids': list
}
diff --git a/plotly/tests/test_core/test_api/__init__.py b/plotly/tests/test_core/test_api/__init__.py
new file mode 100644
index 00000000000..f8a93ee0238
--- /dev/null
+++ b/plotly/tests/test_core/test_api/__init__.py
@@ -0,0 +1,59 @@
+from __future__ import absolute_import
+
+from mock import patch
+from requests import Response
+
+from plotly.session import sign_in
+from plotly.tests.utils import PlotlyTestCase
+
+
+class PlotlyApiTestCase(PlotlyTestCase):
+
+ def mock(self, path_string):
+ patcher = patch(path_string)
+ new_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+ return new_mock
+
+ def setUp(self):
+
+ super(PlotlyApiTestCase, self).setUp()
+
+ self.username = 'foo'
+ self.api_key = 'bar'
+
+ self.proxy_username = 'cnet'
+ self.proxy_password = 'hoopla'
+ self.stream_ids = ['heyThere']
+
+ self.plotly_api_domain = 'https://api.do.not.exist'
+ self.plotly_domain = 'https://who.am.i'
+ self.plotly_proxy_authorization = False
+ self.plotly_streaming_domain = 'stream.does.not.exist'
+ self.plotly_ssl_verification = True
+
+ sign_in(
+ username=self.username,
+ api_key=self.api_key,
+ proxy_username=self.proxy_username,
+ proxy_password=self.proxy_password,
+ stream_ids = self.stream_ids,
+ plotly_domain=self.plotly_domain,
+ plotly_api_domain=self.plotly_api_domain,
+ plotly_streaming_domain=self.plotly_streaming_domain,
+ plotly_proxy_authorization=self.plotly_proxy_authorization,
+ plotly_ssl_verification=self.plotly_ssl_verification
+ )
+
+ def to_bytes(self, string):
+ try:
+ return string.encode('utf-8')
+ except AttributeError:
+ return string
+
+ def get_response(self, content=b'', status_code=200):
+ response = Response()
+ response.status_code = status_code
+ response._content = content
+ response.encoding = 'utf-8'
+ return response
diff --git a/plotly/tests/test_core/test_api/test_v1/__init__.py b/plotly/tests/test_core/test_api/test_v1/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/plotly/tests/test_core/test_api/test_v1/test_clientresp.py b/plotly/tests/test_core/test_api/test_v1/test_clientresp.py
new file mode 100644
index 00000000000..784ca087642
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v1/test_clientresp.py
@@ -0,0 +1,61 @@
+from __future__ import absolute_import
+
+from plotly import version
+from plotly.api.v1 import clientresp
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class Duck(object):
+ def to_plotly_json(self):
+ return 'what else floats?'
+
+
+class ClientrespTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(ClientrespTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v1.utils.requests.request')
+ self.request_mock.return_value = self.get_response(b'{}', 200)
+
+ # Mock the validation function since we can test that elsewhere.
+ self.mock('plotly.api.v1.utils.validate_response')
+
+ def test_data_only(self):
+ data = [{'y': [3, 5], 'name': Duck()}]
+ clientresp(data)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(url, '{}/clientresp'.format(self.plotly_domain))
+ expected_data = ({
+ 'origin': 'plot',
+ 'args': '[{"name": "what else floats?", "y": [3, 5]}]',
+ 'platform': 'python', 'version': version.__version__, 'key': 'bar',
+ 'kwargs': '{}', 'un': 'foo'
+ })
+ self.assertEqual(kwargs['data'], expected_data)
+ self.assertTrue(kwargs['verify'])
+ self.assertEqual(kwargs['headers'], {})
+
+ def test_data_and_kwargs(self):
+ data = [{'y': [3, 5], 'name': Duck()}]
+ clientresp_kwargs = {'layout': {'title': 'mah plot'}, 'filename': 'ok'}
+ clientresp(data, **clientresp_kwargs)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(url, '{}/clientresp'.format(self.plotly_domain))
+ expected_data = ({
+ 'origin': 'plot',
+ 'args': '[{"name": "what else floats?", "y": [3, 5]}]',
+ 'platform': 'python', 'version': version.__version__, 'key': 'bar',
+ 'kwargs': '{"filename": "ok", "layout": {"title": "mah plot"}}',
+ 'un': 'foo'
+ })
+ self.assertEqual(kwargs['data'], expected_data)
+ self.assertTrue(kwargs['verify'])
+ self.assertEqual(kwargs['headers'], {})
diff --git a/plotly/tests/test_core/test_api/test_v1/test_utils.py b/plotly/tests/test_core/test_api/test_v1/test_utils.py
new file mode 100644
index 00000000000..dee352db785
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v1/test_utils.py
@@ -0,0 +1,175 @@
+from __future__ import absolute_import
+
+from unittest import TestCase
+
+from mock import MagicMock, patch
+from requests import Response
+from requests.compat import json as _json
+from requests.exceptions import ConnectionError
+
+from plotly.api.utils import to_native_utf8_string
+from plotly.api.v1 import utils
+from plotly.exceptions import PlotlyError, PlotlyRequestError
+from plotly.session import sign_in
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+from plotly.tests.utils import PlotlyTestCase
+
+
+class ValidateResponseTest(PlotlyApiTestCase):
+
+ def test_validate_ok(self):
+ try:
+ utils.validate_response(self.get_response(content=b'{}'))
+ except PlotlyRequestError:
+ self.fail('Expected this to pass!')
+
+ def test_validate_not_ok(self):
+ bad_status_codes = (400, 404, 500)
+ for bad_status_code in bad_status_codes:
+ response = self.get_response(content=b'{}',
+ status_code=bad_status_code)
+ self.assertRaises(PlotlyRequestError, utils.validate_response,
+ response)
+
+ def test_validate_no_content(self):
+
+ # We shouldn't flake if the response has no content.
+
+ response = self.get_response(content=b'', status_code=200)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, 'No Content')
+ self.assertEqual(e.status_code, 200)
+ self.assertEqual(e.content, b'')
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_non_json_content(self):
+ response = self.get_response(content=b'foobar', status_code=200)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, 'foobar')
+ self.assertEqual(e.status_code, 200)
+ self.assertEqual(e.content, b'foobar')
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_json_content_array(self):
+ content = self.to_bytes(_json.dumps([1, 2, 3]))
+ response = self.get_response(content=content, status_code=200)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, to_native_utf8_string(content))
+ self.assertEqual(e.status_code, 200)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_json_content_dict_no_error(self):
+ content = self.to_bytes(_json.dumps({'foo': 'bar'}))
+ response = self.get_response(content=content, status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, to_native_utf8_string(content))
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_json_content_dict_error_empty(self):
+ content = self.to_bytes(_json.dumps({'error': ''}))
+ response = self.get_response(content=content, status_code=200)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError:
+ self.fail('Expected this not to raise!')
+
+ def test_validate_json_content_dict_one_error_ok(self):
+ content = self.to_bytes(_json.dumps({'error': 'not ok!'}))
+ response = self.get_response(content=content, status_code=200)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, 'not ok!')
+ self.assertEqual(e.status_code, 200)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+
+class GetHeadersTest(PlotlyTestCase):
+
+ def setUp(self):
+ super(GetHeadersTest, self).setUp()
+ self.domain = 'https://foo.bar'
+ self.username = 'hodor'
+ self.api_key = 'secret'
+ sign_in(self.username, self.api_key, proxy_username='kleen-kanteen',
+ proxy_password='hydrated', plotly_proxy_authorization=False)
+
+ def test_normal_auth(self):
+ headers = utils.get_headers()
+ expected_headers = {}
+ self.assertEqual(headers, expected_headers)
+
+ def test_proxy_auth(self):
+ sign_in(self.username, self.api_key, plotly_proxy_authorization=True)
+ headers = utils.get_headers()
+ expected_headers = {
+ 'authorization': 'Basic a2xlZW4ta2FudGVlbjpoeWRyYXRlZA=='
+ }
+ self.assertEqual(headers, expected_headers)
+
+
+class RequestTest(PlotlyTestCase):
+
+ def setUp(self):
+ super(RequestTest, self).setUp()
+ self.domain = 'https://foo.bar'
+ self.username = 'hodor'
+ self.api_key = 'secret'
+ sign_in(self.username, self.api_key, proxy_username='kleen-kanteen',
+ proxy_password='hydrated', plotly_proxy_authorization=False)
+
+ # Mock the actual api call, we don't want to do network tests here.
+ patcher = patch('plotly.api.v1.utils.requests.request')
+ self.request_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+ self.request_mock.return_value = MagicMock(Response)
+
+ # Mock the validation function since we test that elsewhere.
+ patcher = patch('plotly.api.v1.utils.validate_response')
+ self.validate_response_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+
+ self.method = 'get'
+ self.url = 'https://foo.bar.does.not.exist.anywhere'
+
+ def test_request_with_json(self):
+
+ # You can pass along non-native objects in the `json` kwarg for a
+ # requests.request, however, V1 packs up json objects a little
+ # differently, so we don't allow such requests.
+
+ self.assertRaises(PlotlyError, utils.request, self.method,
+ self.url, json={})
+
+ def test_request_with_ConnectionError(self):
+
+ # requests can flake out and not return a response object, we want to
+ # make sure we remain consistent with our errors.
+
+ self.request_mock.side_effect = ConnectionError()
+ self.assertRaises(PlotlyRequestError, utils.request, self.method,
+ self.url)
+
+ def test_request_validate_response(self):
+
+ # Finally, we check details elsewhere, but make sure we do validate.
+
+ utils.request(self.method, self.url)
+ self.validate_response_mock.assert_called_once()
diff --git a/plotly/tests/test_core/test_api/test_v2/__init__.py b/plotly/tests/test_core/test_api/test_v2/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/plotly/tests/test_core/test_api/test_v2/test_files.py b/plotly/tests/test_core/test_api/test_v2/test_files.py
new file mode 100644
index 00000000000..32e4ec99347
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v2/test_files.py
@@ -0,0 +1,104 @@
+from __future__ import absolute_import
+
+from plotly.api.v2 import files
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class FilesTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(FilesTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v2.utils.requests.request')
+ self.request_mock.return_value = self.get_response()
+
+ # Mock the validation function since we can test that elsewhere.
+ self.mock('plotly.api.v2.utils.validate_response')
+
+ def test_retrieve(self):
+ files.retrieve('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/files/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {})
+
+ def test_retrieve_share_key(self):
+ files.retrieve('hodor:88', share_key='foobar')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/files/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {'share_key': 'foobar'})
+
+ def test_update(self):
+ new_filename = '..zzZ ..zzZ'
+ files.update('hodor:88', body={'filename': new_filename})
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'put')
+ self.assertEqual(
+ url, '{}/v2/files/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['data'],
+ '{{"filename": "{}"}}'.format(new_filename))
+
+ def test_trash(self):
+ files.trash('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/files/hodor:88/trash'.format(self.plotly_api_domain)
+ )
+
+ def test_restore(self):
+ files.restore('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/files/hodor:88/restore'.format(self.plotly_api_domain)
+ )
+
+ def test_permanent_delete(self):
+ files.permanent_delete('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'delete')
+ self.assertEqual(
+ url,
+ '{}/v2/files/hodor:88/permanent_delete'
+ .format(self.plotly_api_domain)
+ )
+
+ def test_lookup(self):
+
+ # requests does urlencode, so don't worry about the `' '` character!
+
+ path = '/mah plot'
+ parent = 43
+ user = 'someone'
+ exists = True
+ files.lookup(path=path, parent=parent, user=user, exists=exists)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ expected_params = {'path': path, 'parent': parent, 'exists': 'true',
+ 'user': user}
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/files/lookup'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], expected_params)
diff --git a/plotly/tests/test_core/test_api/test_v2/test_folders.py b/plotly/tests/test_core/test_api/test_v2/test_folders.py
new file mode 100644
index 00000000000..0365ad79879
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v2/test_folders.py
@@ -0,0 +1,114 @@
+from __future__ import absolute_import
+
+from plotly.api.v2 import folders
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class FoldersTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(FoldersTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v2.utils.requests.request')
+ self.request_mock.return_value = self.get_response()
+
+ # Mock the validation function since we can test that elsewhere.
+ self.mock('plotly.api.v2.utils.validate_response')
+
+ def test_create(self):
+ path = '/foo/man/bar/'
+ folders.create({'path': path})
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(url, '{}/v2/folders'.format(self.plotly_api_domain))
+ self.assertEqual(kwargs['data'], '{{"path": "{}"}}'.format(path))
+
+ def test_retrieve(self):
+ folders.retrieve('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/folders/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {})
+
+ def test_retrieve_share_key(self):
+ folders.retrieve('hodor:88', share_key='foobar')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/folders/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {'share_key': 'foobar'})
+
+ def test_update(self):
+ new_filename = '..zzZ ..zzZ'
+ folders.update('hodor:88', body={'filename': new_filename})
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'put')
+ self.assertEqual(
+ url, '{}/v2/folders/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['data'],
+ '{{"filename": "{}"}}'.format(new_filename))
+
+ def test_trash(self):
+ folders.trash('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/folders/hodor:88/trash'.format(self.plotly_api_domain)
+ )
+
+ def test_restore(self):
+ folders.restore('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/folders/hodor:88/restore'.format(self.plotly_api_domain)
+ )
+
+ def test_permanent_delete(self):
+ folders.permanent_delete('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'delete')
+ self.assertEqual(
+ url,
+ '{}/v2/folders/hodor:88/permanent_delete'
+ .format(self.plotly_api_domain)
+ )
+
+ def test_lookup(self):
+
+ # requests does urlencode, so don't worry about the `' '` character!
+
+ path = '/mah folder'
+ parent = 43
+ user = 'someone'
+ exists = True
+ folders.lookup(path=path, parent=parent, user=user, exists=exists)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ expected_params = {'path': path, 'parent': parent, 'exists': 'true',
+ 'user': user}
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/folders/lookup'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], expected_params)
diff --git a/plotly/tests/test_core/test_api/test_v2/test_grids.py b/plotly/tests/test_core/test_api/test_v2/test_grids.py
new file mode 100644
index 00000000000..ff6fb3ec1b3
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v2/test_grids.py
@@ -0,0 +1,185 @@
+from __future__ import absolute_import
+
+from requests.compat import json as _json
+
+from plotly.api.v2 import grids
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class GridsTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(GridsTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v2.utils.requests.request')
+ self.request_mock.return_value = self.get_response()
+
+ # Mock the validation function since we can test that elsewhere.
+ self.mock('plotly.api.v2.utils.validate_response')
+
+ def test_create(self):
+ filename = 'a grid'
+ grids.create({'filename': filename})
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(url, '{}/v2/grids'.format(self.plotly_api_domain))
+ self.assertEqual(
+ kwargs['data'], '{{"filename": "{}"}}'.format(filename)
+ )
+
+ def test_retrieve(self):
+ grids.retrieve('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {})
+
+ def test_retrieve_share_key(self):
+ grids.retrieve('hodor:88', share_key='foobar')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {'share_key': 'foobar'})
+
+ def test_update(self):
+ new_filename = '..zzZ ..zzZ'
+ grids.update('hodor:88', body={'filename': new_filename})
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'put')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['data'],
+ '{{"filename": "{}"}}'.format(new_filename))
+
+ def test_trash(self):
+ grids.trash('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88/trash'.format(self.plotly_api_domain)
+ )
+
+ def test_restore(self):
+ grids.restore('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88/restore'.format(self.plotly_api_domain)
+ )
+
+ def test_permanent_delete(self):
+ grids.permanent_delete('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'delete')
+ self.assertEqual(
+ url,
+ '{}/v2/grids/hodor:88/permanent_delete'
+ .format(self.plotly_api_domain)
+ )
+
+ def test_lookup(self):
+
+ # requests does urlencode, so don't worry about the `' '` character!
+
+ path = '/mah grid'
+ parent = 43
+ user = 'someone'
+ exists = True
+ grids.lookup(path=path, parent=parent, user=user, exists=exists)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ expected_params = {'path': path, 'parent': parent, 'exists': 'true',
+ 'user': user}
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/grids/lookup'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], expected_params)
+
+ def test_col_create(self):
+ cols = [
+ {'name': 'foo', 'data': [1, 2, 3]},
+ {'name': 'bar', 'data': ['b', 'a', 'r']},
+ ]
+ body = {'cols': _json.dumps(cols, sort_keys=True)}
+ grids.col_create('hodor:88', body)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88/col'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['data'], _json.dumps(body, sort_keys=True))
+
+ def test_col_retrieve(self):
+ grids.col_retrieve('hodor:88', 'aaaaaa,bbbbbb')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88/col'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {'uid': 'aaaaaa,bbbbbb'})
+
+ def test_col_update(self):
+ cols = [
+ {'name': 'foo', 'data': [1, 2, 3]},
+ {'name': 'bar', 'data': ['b', 'a', 'r']},
+ ]
+ body = {'cols': _json.dumps(cols, sort_keys=True)}
+ grids.col_update('hodor:88', 'aaaaaa,bbbbbb', body)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'put')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88/col'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {'uid': 'aaaaaa,bbbbbb'})
+ self.assertEqual(kwargs['data'], _json.dumps(body, sort_keys=True))
+
+ def test_col_delete(self):
+ grids.col_delete('hodor:88', 'aaaaaa,bbbbbb')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'delete')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88/col'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {'uid': 'aaaaaa,bbbbbb'})
+
+ def test_row(self):
+ body = {'rows': [[1, 'A'], [2, 'B']]}
+ grids.row('hodor:88', body)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/grids/hodor:88/row'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['data'], _json.dumps(body, sort_keys=True))
diff --git a/plotly/tests/test_core/test_api/test_v2/test_images.py b/plotly/tests/test_core/test_api/test_v2/test_images.py
new file mode 100644
index 00000000000..480cf0f05bf
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v2/test_images.py
@@ -0,0 +1,41 @@
+from __future__ import absolute_import
+
+from requests.compat import json as _json
+
+from plotly.api.v2 import images
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class ImagesTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(ImagesTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v2.utils.requests.request')
+ self.request_mock.return_value = self.get_response()
+
+ # Mock the validation function since we can test that elsewhere.
+ self.mock('plotly.api.v2.utils.validate_response')
+
+ def test_create(self):
+
+ body = {
+ "figure": {
+ "data": [{"y": [10, 10, 2, 20]}],
+ "layout": {"width": 700}
+ },
+ "width": 1000,
+ "height": 500,
+ "format": "png",
+ "scale": 4,
+ "encoded": False
+ }
+
+ images.create(body)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(url, '{}/v2/images'.format(self.plotly_api_domain))
+ self.assertEqual(kwargs['data'], _json.dumps(body, sort_keys=True))
diff --git a/plotly/tests/test_core/test_api/test_v2/test_plot_schema.py b/plotly/tests/test_core/test_api/test_v2/test_plot_schema.py
new file mode 100644
index 00000000000..b52f1b3a000
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v2/test_plot_schema.py
@@ -0,0 +1,30 @@
+from __future__ import absolute_import
+
+from plotly.api.v2 import plot_schema
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class PlotSchemaTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(PlotSchemaTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v2.utils.requests.request')
+ self.request_mock.return_value = self.get_response()
+
+ # Mock the validation function since we can test that elsewhere.
+ self.mock('plotly.api.v2.utils.validate_response')
+
+ def test_retrieve(self):
+
+ plot_schema.retrieve('some-hash', timeout=400)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/plot-schema'.format(self.plotly_api_domain)
+ )
+ self.assertTrue(kwargs['timeout'])
+ self.assertEqual(kwargs['params'], {'sha1': 'some-hash'})
diff --git a/plotly/tests/test_core/test_api/test_v2/test_plots.py b/plotly/tests/test_core/test_api/test_v2/test_plots.py
new file mode 100644
index 00000000000..31d50cb7aaf
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v2/test_plots.py
@@ -0,0 +1,116 @@
+from __future__ import absolute_import
+
+from plotly.api.v2 import plots
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class PlotsTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(PlotsTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v2.utils.requests.request')
+ self.request_mock.return_value = self.get_response()
+
+ # Mock the validation function since we can test that elsewhere.
+ self.mock('plotly.api.v2.utils.validate_response')
+
+ def test_create(self):
+ filename = 'a plot'
+ plots.create({'filename': filename})
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(url, '{}/v2/plots'.format(self.plotly_api_domain))
+ self.assertEqual(
+ kwargs['data'], '{{"filename": "{}"}}'.format(filename)
+ )
+
+ def test_retrieve(self):
+ plots.retrieve('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/plots/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {})
+
+ def test_retrieve_share_key(self):
+ plots.retrieve('hodor:88', share_key='foobar')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/plots/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], {'share_key': 'foobar'})
+
+ def test_update(self):
+ new_filename = '..zzZ ..zzZ'
+ plots.update('hodor:88', body={'filename': new_filename})
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'put')
+ self.assertEqual(
+ url, '{}/v2/plots/hodor:88'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['data'],
+ '{{"filename": "{}"}}'.format(new_filename))
+
+ def test_trash(self):
+ plots.trash('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/plots/hodor:88/trash'.format(self.plotly_api_domain)
+ )
+
+ def test_restore(self):
+ plots.restore('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'post')
+ self.assertEqual(
+ url, '{}/v2/plots/hodor:88/restore'.format(self.plotly_api_domain)
+ )
+
+ def test_permanent_delete(self):
+ plots.permanent_delete('hodor:88')
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'delete')
+ self.assertEqual(
+ url,
+ '{}/v2/plots/hodor:88/permanent_delete'
+ .format(self.plotly_api_domain)
+ )
+
+ def test_lookup(self):
+
+ # requests does urlencode, so don't worry about the `' '` character!
+
+ path = '/mah plot'
+ parent = 43
+ user = 'someone'
+ exists = True
+ plots.lookup(path=path, parent=parent, user=user, exists=exists)
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ expected_params = {'path': path, 'parent': parent, 'exists': 'true',
+ 'user': user}
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/plots/lookup'.format(self.plotly_api_domain)
+ )
+ self.assertEqual(kwargs['params'], expected_params)
diff --git a/plotly/tests/test_core/test_api/test_v2/test_users.py b/plotly/tests/test_core/test_api/test_v2/test_users.py
new file mode 100644
index 00000000000..59cf8731d56
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v2/test_users.py
@@ -0,0 +1,28 @@
+from __future__ import absolute_import
+
+from plotly.api.v2 import users
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class UsersTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(UsersTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v2.utils.requests.request')
+ self.request_mock.return_value = self.get_response()
+
+ # Mock the validation function since we can test that elsewhere.
+ self.mock('plotly.api.v2.utils.validate_response')
+
+ def test_current(self):
+ users.current()
+ self.request_mock.assert_called_once()
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ self.assertEqual(method, 'get')
+ self.assertEqual(
+ url, '{}/v2/users/current'.format(self.plotly_api_domain)
+ )
+ self.assertNotIn('params', kwargs)
diff --git a/plotly/tests/test_core/test_api/test_v2/test_utils.py b/plotly/tests/test_core/test_api/test_v2/test_utils.py
new file mode 100644
index 00000000000..c370ef418ed
--- /dev/null
+++ b/plotly/tests/test_core/test_api/test_v2/test_utils.py
@@ -0,0 +1,252 @@
+from __future__ import absolute_import
+
+from requests.compat import json as _json
+from requests.exceptions import ConnectionError
+
+from plotly import version
+from plotly.api.utils import to_native_utf8_string
+from plotly.api.v2 import utils
+from plotly.exceptions import PlotlyRequestError
+from plotly.session import sign_in
+from plotly.tests.test_core.test_api import PlotlyApiTestCase
+
+
+class MakeParamsTest(PlotlyApiTestCase):
+
+ def test_make_params(self):
+ params = utils.make_params(foo='FOO', bar=None)
+ self.assertEqual(params, {'foo': 'FOO'})
+
+ def test_make_params_empty(self):
+ params = utils.make_params(foo=None, bar=None)
+ self.assertEqual(params, {})
+
+
+class BuildUrlTest(PlotlyApiTestCase):
+
+ def test_build_url(self):
+ url = utils.build_url('cats')
+ self.assertEqual(url, '{}/v2/cats'.format(self.plotly_api_domain))
+
+ def test_build_url_id(self):
+ url = utils.build_url('cats', id='MsKitty')
+ self.assertEqual(
+ url, '{}/v2/cats/MsKitty'.format(self.plotly_api_domain)
+ )
+
+ def test_build_url_route(self):
+ url = utils.build_url('cats', route='about')
+ self.assertEqual(
+ url, '{}/v2/cats/about'.format(self.plotly_api_domain)
+ )
+
+ def test_build_url_id_route(self):
+ url = utils.build_url('cats', id='MsKitty', route='de-claw')
+ self.assertEqual(
+ url, '{}/v2/cats/MsKitty/de-claw'.format(self.plotly_api_domain)
+ )
+
+
+class ValidateResponseTest(PlotlyApiTestCase):
+
+ def test_validate_ok(self):
+ try:
+ utils.validate_response(self.get_response())
+ except PlotlyRequestError:
+ self.fail('Expected this to pass!')
+
+ def test_validate_not_ok(self):
+ bad_status_codes = (400, 404, 500)
+ for bad_status_code in bad_status_codes:
+ response = self.get_response(status_code=bad_status_code)
+ self.assertRaises(PlotlyRequestError, utils.validate_response,
+ response)
+
+ def test_validate_no_content(self):
+
+ # We shouldn't flake if the response has no content.
+
+ response = self.get_response(content=b'', status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, u'No Content')
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content.decode('utf-8'), u'')
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_non_json_content(self):
+ response = self.get_response(content=b'foobar', status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, 'foobar')
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content, b'foobar')
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_json_content_array(self):
+ content = self.to_bytes(_json.dumps([1, 2, 3]))
+ response = self.get_response(content=content, status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, to_native_utf8_string(content))
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_json_content_dict_no_errors(self):
+ content = self.to_bytes(_json.dumps({'foo': 'bar'}))
+ response = self.get_response(content=content, status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, to_native_utf8_string(content))
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_json_content_dict_one_error_bad(self):
+ content = self.to_bytes(_json.dumps({'errors': [{}]}))
+ response = self.get_response(content=content, status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, to_native_utf8_string(content))
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+ content = self.to_bytes(_json.dumps({'errors': [{'message': ''}]}))
+ response = self.get_response(content=content, status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, to_native_utf8_string(content))
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_json_content_dict_one_error_ok(self):
+ content = self.to_bytes(_json.dumps(
+ {'errors': [{'message': 'not ok!'}]}))
+ response = self.get_response(content=content, status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, 'not ok!')
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+ def test_validate_json_content_dict_multiple_errors(self):
+ content = self.to_bytes(_json.dumps({'errors': [
+ {'message': 'not ok!'}, {'message': 'bad job...'}
+ ]}))
+ response = self.get_response(content=content, status_code=400)
+ try:
+ utils.validate_response(response)
+ except PlotlyRequestError as e:
+ self.assertEqual(e.message, 'not ok!\nbad job...')
+ self.assertEqual(e.status_code, 400)
+ self.assertEqual(e.content, content)
+ else:
+ self.fail('Expected this to raise!')
+
+
+class GetHeadersTest(PlotlyApiTestCase):
+
+ def test_normal_auth(self):
+ headers = utils.get_headers()
+ expected_headers = {
+ 'plotly-client-platform': 'python {}'.format(version.__version__),
+ 'authorization': 'Basic Zm9vOmJhcg==',
+ 'content-type': 'application/json'
+ }
+ self.assertEqual(headers, expected_headers)
+
+ def test_proxy_auth(self):
+ sign_in(self.username, self.api_key, plotly_proxy_authorization=True)
+ headers = utils.get_headers()
+ expected_headers = {
+ 'plotly-client-platform': 'python {}'.format(version.__version__),
+ 'authorization': 'Basic Y25ldDpob29wbGE=',
+ 'plotly-authorization': 'Basic Zm9vOmJhcg==',
+ 'content-type': 'application/json'
+ }
+ self.assertEqual(headers, expected_headers)
+
+
+class RequestTest(PlotlyApiTestCase):
+
+ def setUp(self):
+ super(RequestTest, self).setUp()
+
+ # Mock the actual api call, we don't want to do network tests here.
+ self.request_mock = self.mock('plotly.api.v2.utils.requests.request')
+ self.request_mock.return_value = self.get_response()
+
+ # Mock the validation function since we can test that elsewhere.
+ self.validate_response_mock = self.mock(
+ 'plotly.api.v2.utils.validate_response')
+
+ self.method = 'get'
+ self.url = 'https://foo.bar.does.not.exist.anywhere'
+
+ def test_request_with_params(self):
+
+ # urlencode transforms `True` --> `'True'`, which isn't super helpful,
+ # Our backend accepts the JS `true`, so we want `True` --> `'true'`.
+
+ params = {'foo': True, 'bar': 'True', 'baz': False, 'zap': 0}
+ utils.request(self.method, self.url, params=params)
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ expected_params = {'foo': 'true', 'bar': 'True', 'baz': 'false',
+ 'zap': 0}
+ self.assertEqual(method, self.method)
+ self.assertEqual(url, self.url)
+ self.assertEqual(kwargs['params'], expected_params)
+
+ def test_request_with_non_native_objects(self):
+
+ # We always send along json, but it may contain non-native objects like
+ # a pandas array or a Column reference. Make sure that's handled in one
+ # central place.
+
+ class Duck(object):
+ def to_plotly_json(self):
+ return 'what else floats?'
+
+ utils.request(self.method, self.url, json={'foo': [Duck(), Duck()]})
+ args, kwargs = self.request_mock.call_args
+ method, url = args
+ expected_data = '{"foo": ["what else floats?", "what else floats?"]}'
+ self.assertEqual(method, self.method)
+ self.assertEqual(url, self.url)
+ self.assertEqual(kwargs['data'], expected_data)
+ self.assertNotIn('json', kwargs)
+
+ def test_request_with_ConnectionError(self):
+
+ # requests can flake out and not return a response object, we want to
+ # make sure we remain consistent with our errors.
+
+ self.request_mock.side_effect = ConnectionError()
+ self.assertRaises(PlotlyRequestError, utils.request, self.method,
+ self.url)
+
+ def test_request_validate_response(self):
+
+ # Finally, we check details elsewhere, but make sure we do validate.
+
+ utils.request(self.method, self.url)
+ self.validate_response_mock.assert_called_once()
diff --git a/plotly/tests/test_core/test_file/test_file.py b/plotly/tests/test_core/test_file/test_file.py
index cd4ea5b9a47..c8b3bb8680a 100644
--- a/plotly/tests/test_core/test_file/test_file.py
+++ b/plotly/tests/test_core/test_file/test_file.py
@@ -49,7 +49,7 @@ def test_duplicate_folders(self):
py.file_ops.mkdirs(first_folder)
try:
py.file_ops.mkdirs(first_folder)
- except requests.exceptions.RequestException as e:
- self.assertTrue(400 <= e.response.status_code < 500)
+ except PlotlyRequestError as e:
+ self.assertTrue(400 <= e.status_code < 500)
else:
self.fail('Expected this to fail!')
diff --git a/plotly/tests/test_core/test_get_requests/test_get_requests.py b/plotly/tests/test_core/test_get_requests/test_get_requests.py
index 4c4fd939e47..1719d86b38d 100644
--- a/plotly/tests/test_core/test_get_requests/test_get_requests.py
+++ b/plotly/tests/test_core/test_get_requests/test_get_requests.py
@@ -6,11 +6,11 @@
"""
import copy
-import json
-import requests
+import requests
import six
from nose.plugins.attrib import attr
+from requests.compat import json as _json
default_headers = {'plotly-username': '',
@@ -37,9 +37,9 @@ def test_user_does_not_exist():
resource = "/apigetfile/{0}/{1}/".format(file_owner, file_id)
response = requests.get(server + resource, headers=hd)
if six.PY3:
- content = json.loads(response.content.decode('unicode_escape'))
+ content = _json.loads(response.content.decode('unicode_escape'))
else:
- content = json.loads(response.content)
+ content = _json.loads(response.content)
print(response.status_code)
print(content)
assert response.status_code == 404
@@ -60,9 +60,9 @@ def test_file_does_not_exist():
resource = "/apigetfile/{0}/{1}/".format(file_owner, file_id)
response = requests.get(server + resource, headers=hd)
if six.PY3:
- content = json.loads(response.content.decode('unicode_escape'))
+ content = _json.loads(response.content.decode('unicode_escape'))
else:
- content = json.loads(response.content)
+ content = _json.loads(response.content)
print(response.status_code)
print(content)
assert response.status_code == 404
@@ -100,9 +100,9 @@ def test_private_permission_defined():
resource = "/apigetfile/{0}/{1}/".format(file_owner, file_id)
response = requests.get(server + resource, headers=hd)
if six.PY3:
- content = json.loads(response.content.decode('unicode_escape'))
+ content = _json.loads(response.content.decode('unicode_escape'))
else:
- content = json.loads(response.content)
+ content = _json.loads(response.content)
print(response.status_code)
print(content)
assert response.status_code == 403
@@ -122,9 +122,9 @@ def test_missing_headers():
del hd[header]
response = requests.get(server + resource, headers=hd)
if six.PY3:
- content = json.loads(response.content.decode('unicode_escape'))
+ content = _json.loads(response.content.decode('unicode_escape'))
else:
- content = json.loads(response.content)
+ content = _json.loads(response.content)
print(response.status_code)
print(content)
assert response.status_code == 422
@@ -142,13 +142,13 @@ def test_valid_request():
resource = "/apigetfile/{0}/{1}/".format(file_owner, file_id)
response = requests.get(server + resource, headers=hd)
if six.PY3:
- content = json.loads(response.content.decode('unicode_escape'))
+ content = _json.loads(response.content.decode('unicode_escape'))
else:
- content = json.loads(response.content)
+ content = _json.loads(response.content)
print(response.status_code)
print(content)
assert response.status_code == 200
- # content = json.loads(res.content)
+ # content = _json.loads(res.content)
# response_payload = content['payload']
# figure = response_payload['figure']
# if figure['data'][0]['x'] != [u'1', u'2', u'3']:
diff --git a/plotly/tests/test_core/test_graph_reference/test_graph_reference.py b/plotly/tests/test_core/test_graph_reference/test_graph_reference.py
index f3f97c7c5a8..005d33a7c05 100644
--- a/plotly/tests/test_core/test_graph_reference/test_graph_reference.py
+++ b/plotly/tests/test_core/test_graph_reference/test_graph_reference.py
@@ -4,36 +4,31 @@
"""
from __future__ import absolute_import
-import json
import os
from pkg_resources import resource_string
from unittest import TestCase
-import requests
-import six
from nose.plugins.attrib import attr
+from requests.compat import json as _json
-from plotly import files, graph_reference as gr
+from plotly import graph_reference as gr
+from plotly.api import v2
from plotly.graph_reference import string_to_class_name, get_role
from plotly.tests.utils import PlotlyTestCase
+FAKE_API_DOMAIN = 'https://api.am.not.here.ly'
+
class TestGraphReferenceCaching(PlotlyTestCase):
@attr('slow')
def test_default_schema_is_up_to_date(self):
- api_domain = files.FILE_CONTENT[files.CONFIG_FILE]['plotly_api_domain']
- graph_reference_url = '{}{}?sha1'.format(api_domain, '/v2/plot-schema')
- response = requests.get(graph_reference_url)
- if six.PY3:
- content = str(response.content, encoding='utf-8')
- else:
- content = response.content
- schema = json.loads(content)['schema']
+ response = v2.plot_schema.retrieve('')
+ schema = response.json()['schema']
path = os.path.join('package_data', 'default-schema.json')
s = resource_string('plotly', path).decode('utf-8')
- default_schema = json.loads(s)
+ default_schema = _json.loads(s)
msg = (
'The default, hard-coded plot schema we ship with pip is out of '
diff --git a/plotly/tests/test_core/test_grid/test_grid.py b/plotly/tests/test_core/test_grid/test_grid.py
index d711fd8fd73..4ccd3690136 100644
--- a/plotly/tests/test_core/test_grid/test_grid.py
+++ b/plotly/tests/test_core/test_grid/test_grid.py
@@ -9,7 +9,6 @@
import random
import string
-import requests
from nose import with_setup
from nose.plugins.attrib import attr
@@ -17,10 +16,10 @@
from unittest import skip
import plotly.plotly as py
-from plotly.exceptions import InputError, PlotlyRequestError
+from plotly.exceptions import InputError, PlotlyRequestError, PlotlyError
from plotly.graph_objs import Scatter
from plotly.grid_objs import Column, Grid
-from plotly.plotly.plotly import _api_v2
+from plotly.plotly.plotly import parse_grid_id_args
def random_filename():
@@ -124,19 +123,18 @@ def test_get_figure_from_references():
def test_grid_id_args():
assert(
- _api_v2.parse_grid_id_args(_grid, None) ==
- _api_v2.parse_grid_id_args(None, _grid_url)
+ parse_grid_id_args(_grid, None) == parse_grid_id_args(None, _grid_url)
)
@raises(InputError)
def test_no_grid_id_args():
- _api_v2.parse_grid_id_args(None, None)
+ parse_grid_id_args(None, None)
@raises(InputError)
def test_overspecified_grid_args():
- _api_v2.parse_grid_id_args(_grid, _grid_url)
+ parse_grid_id_args(_grid, _grid_url)
# Out of order usage
@@ -149,8 +147,7 @@ def test_scatter_from_non_uploaded_grid():
Scatter(xsrc=g[0], ysrc=g[1])
-@attr('slow')
-@raises(requests.exceptions.HTTPError)
+@raises(PlotlyError)
def test_column_append_of_non_uploaded_grid():
c1 = Column([1, 2, 3, 4], 'first column')
c2 = Column(['a', 'b', 'c', 'd'], 'second column')
@@ -158,8 +155,7 @@ def test_column_append_of_non_uploaded_grid():
py.grid_ops.append_columns([c2], grid=g)
-@attr('slow')
-@raises(requests.exceptions.HTTPError)
+@raises(PlotlyError)
def test_row_append_of_non_uploaded_grid():
c1 = Column([1, 2, 3, 4], 'first column')
rows = [[1], [2]]
diff --git a/plotly/tests/test_core/test_offline/test_offline.py b/plotly/tests/test_core/test_offline/test_offline.py
index f845709a287..cc4b903de71 100644
--- a/plotly/tests/test_core/test_offline/test_offline.py
+++ b/plotly/tests/test_core/test_offline/test_offline.py
@@ -4,12 +4,13 @@
"""
from __future__ import absolute_import
-from nose.tools import raises
from unittest import TestCase
-from plotly.tests.utils import PlotlyTestCase
-import json
+
+from requests.compat import json as _json
import plotly
+from plotly.tests.utils import PlotlyTestCase
+
fig = {
'data': [
@@ -35,8 +36,9 @@ def _read_html(self, file_url):
return f.read()
def test_default_plot_generates_expected_html(self):
- data_json = json.dumps(fig['data'], cls=plotly.utils.PlotlyJSONEncoder)
- layout_json = json.dumps(
+ data_json = _json.dumps(fig['data'],
+ cls=plotly.utils.PlotlyJSONEncoder)
+ layout_json = _json.dumps(
fig['layout'],
cls=plotly.utils.PlotlyJSONEncoder)
diff --git a/plotly/tests/test_core/test_optional_imports/__init__.py b/plotly/tests/test_core/test_optional_imports/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/plotly/tests/test_core/test_optional_imports/test_optional_imports.py b/plotly/tests/test_core/test_optional_imports/test_optional_imports.py
new file mode 100644
index 00000000000..e7569f1609d
--- /dev/null
+++ b/plotly/tests/test_core/test_optional_imports/test_optional_imports.py
@@ -0,0 +1,24 @@
+from __future__ import absolute_import
+
+from unittest import TestCase
+
+from plotly.optional_imports import get_module
+
+
+class OptionalImportsTest(TestCase):
+
+ def test_get_module_exists(self):
+ import math
+ module = get_module('math')
+ self.assertIsNotNone(module)
+ self.assertEqual(math, module)
+
+ def test_get_module_exists_submodule(self):
+ import requests.sessions
+ module = get_module('requests.sessions')
+ self.assertIsNotNone(module)
+ self.assertEqual(requests.sessions, module)
+
+ def test_get_module_does_not_exist(self):
+ module = get_module('hoopla')
+ self.assertIsNone(module)
diff --git a/plotly/tests/test_core/test_plotly/test_credentials.py b/plotly/tests/test_core/test_plotly/test_credentials.py
index 573b30c91b0..73b9eca8767 100644
--- a/plotly/tests/test_core/test_plotly/test_credentials.py
+++ b/plotly/tests/test_core/test_plotly/test_credentials.py
@@ -1,39 +1,43 @@
from __future__ import absolute_import
-from unittest import TestCase
+from mock import patch
import plotly.plotly.plotly as py
import plotly.session as session
import plotly.tools as tls
+from plotly import exceptions
+from plotly.tests.utils import PlotlyTestCase
-def test_get_credentials():
- session_credentials = session.get_session_credentials()
- if 'username' in session_credentials:
- del session._session['credentials']['username']
- if 'api_key' in session_credentials:
- del session._session['credentials']['api_key']
- creds = py.get_credentials()
- file_creds = tls.get_credentials_file()
- print(creds)
- print(file_creds)
- assert creds == file_creds
+class TestSignIn(PlotlyTestCase):
+ def setUp(self):
+ super(TestSignIn, self).setUp()
+ patcher = patch('plotly.api.v2.users.current')
+ self.users_current_mock = patcher.start()
+ self.addCleanup(patcher.stop)
-def test_sign_in():
- un = 'anyone'
- ak = 'something'
- # TODO, add this!
- # si = ['this', 'and-this']
- py.sign_in(un, ak)
- creds = py.get_credentials()
- assert creds['username'] == un
- assert creds['api_key'] == ak
- # TODO, and check it!
- # assert creds['stream_ids'] == si
+ def test_get_credentials(self):
+ session_credentials = session.get_session_credentials()
+ if 'username' in session_credentials:
+ del session._session['credentials']['username']
+ if 'api_key' in session_credentials:
+ del session._session['credentials']['api_key']
+ creds = py.get_credentials()
+ file_creds = tls.get_credentials_file()
+ self.assertEqual(creds, file_creds)
-
-class TestSignIn(TestCase):
+ def test_sign_in(self):
+ un = 'anyone'
+ ak = 'something'
+ # TODO, add this!
+ # si = ['this', 'and-this']
+ py.sign_in(un, ak)
+ creds = py.get_credentials()
+ self.assertEqual(creds['username'], un)
+ self.assertEqual(creds['api_key'], ak)
+ # TODO, and check it!
+ # assert creds['stream_ids'] == si
def test_get_config(self):
plotly_domain = 'test domain'
@@ -74,3 +78,10 @@ def test_sign_in_with_config(self):
self.assertEqual(
config['plotly_ssl_verification'], plotly_ssl_verification
)
+
+ def test_sign_in_cannot_validate(self):
+ self.users_current_mock.side_effect = exceptions.PlotlyRequestError(
+ 'msg', 400, 'foobar'
+ )
+ with self.assertRaisesRegexp(exceptions.PlotlyError, 'Sign in failed'):
+ py.sign_in('foo', 'bar')
diff --git a/plotly/tests/test_core/test_plotly/test_plot.py b/plotly/tests/test_core/test_plotly/test_plot.py
index 25b6d208aa3..ef7e797cf58 100644
--- a/plotly/tests/test_core/test_plotly/test_plot.py
+++ b/plotly/tests/test_core/test_plotly/test_plot.py
@@ -7,11 +7,12 @@
"""
from __future__ import absolute_import
-import json
import requests
import six
+from requests.compat import json as _json
from unittest import TestCase
+from mock import patch
from nose.plugins.attrib import attr
from nose.tools import raises
@@ -40,9 +41,14 @@ def test_plot_valid(self):
'x': [1, 2, 3],
'y': [2, 1, 2]
}
- ]
+ ],
+ 'layout': {'title': 'simple'}
}
- py.plot(fig, auto_open=False, filename='plot_valid')
+ url = py.plot(fig, auto_open=False, filename='plot_valid')
+ saved_fig = py.get_figure(url)
+ self.assertEqual(saved_fig['data'][0]['x'], fig['data'][0]['x'])
+ self.assertEqual(saved_fig['data'][0]['y'], fig['data'][0]['y'])
+ self.assertEqual(saved_fig['layout']['title'], fig['layout']['title'])
@raises(PlotlyError)
def test_plot_invalid(self):
@@ -223,6 +229,14 @@ class TestPlotOptionLogic(PlotlyTestCase):
{'world_readable': False, 'sharing': 'public'}
)
+ def setUp(self):
+ super(TestPlotOptionLogic, self).setUp()
+
+ # Make sure we don't hit sign-in validation failures.
+ patcher = patch('plotly.api.v2.users.current')
+ self.users_current_mock = patcher.start()
+ self.addCleanup(patcher.stop)
+
def test_default_options(self):
options = py._plot_option_logic({})
config_options = tls.get_config_file()
@@ -296,10 +310,10 @@ def generate_conflicting_plot_options_with_json_writes_of_config():
"""
def gen_test(plot_options):
def test(self):
- config = json.load(open(CONFIG_FILE))
+ config = _json.load(open(CONFIG_FILE))
with open(CONFIG_FILE, 'w') as f:
config.update(plot_options)
- f.write(json.dumps(config))
+ f.write(_json.dumps(config))
self.assertRaises(PlotlyError, py._plot_option_logic, {})
return test
diff --git a/plotly/tests/test_core/test_utils/test_utils.py b/plotly/tests/test_core/test_utils/test_utils.py
index cb38648b8b6..b406a6464ab 100644
--- a/plotly/tests/test_core/test_utils/test_utils.py
+++ b/plotly/tests/test_core/test_utils/test_utils.py
@@ -1,8 +1,9 @@
from __future__ import absolute_import
-import json
from unittest import TestCase
+from requests.compat import json as _json
+
from plotly.utils import PlotlyJSONEncoder, get_by_path, node_generator
@@ -10,7 +11,7 @@ class TestJSONEncoder(TestCase):
def test_nan_to_null(self):
array = [1, float('NaN'), float('Inf'), float('-Inf'), 'platypus']
- result = json.dumps(array, cls=PlotlyJSONEncoder)
+ result = _json.dumps(array, cls=PlotlyJSONEncoder)
expected_result = '[1, null, null, null, "platypus"]'
self.assertEqual(result, expected_result)
diff --git a/plotly/tests/test_optional/optional_utils.py b/plotly/tests/test_optional/optional_utils.py
index 74308de7f53..76941b1b1fe 100644
--- a/plotly/tests/test_optional/optional_utils.py
+++ b/plotly/tests/test_optional/optional_utils.py
@@ -2,12 +2,13 @@
import numpy as np
+from plotly import optional_imports
from plotly.tests.utils import is_num_list
from plotly.utils import get_by_path, node_generator
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
diff --git a/plotly/tests/test_optional/test_matplotlylib/test_annotations.py b/plotly/tests/test_optional/test_matplotlylib/test_annotations.py
index d4987d01b70..b0e238416b9 100644
--- a/plotly/tests/test_optional/test_matplotlylib/test_annotations.py
+++ b/plotly/tests/test_optional/test_matplotlylib/test_annotations.py
@@ -2,10 +2,11 @@
from nose.plugins.attrib import attr
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
+from plotly import optional_imports
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
diff --git a/plotly/tests/test_optional/test_matplotlylib/test_axis_scales.py b/plotly/tests/test_optional/test_matplotlylib/test_axis_scales.py
index 277b886d9fd..2b62f28ceab 100644
--- a/plotly/tests/test_optional/test_matplotlylib/test_axis_scales.py
+++ b/plotly/tests/test_optional/test_matplotlylib/test_axis_scales.py
@@ -2,14 +2,14 @@
from nose.plugins.attrib import attr
+from plotly import optional_imports
from plotly.tests.utils import compare_dict
from plotly.tests.test_optional.optional_utils import run_fig
from plotly.tests.test_optional.test_matplotlylib.data.axis_scales import *
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
-if _matplotlylib_imported:
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
diff --git a/plotly/tests/test_optional/test_matplotlylib/test_bars.py b/plotly/tests/test_optional/test_matplotlylib/test_bars.py
index feb439dd2c8..c89bfb3a8fc 100644
--- a/plotly/tests/test_optional/test_matplotlylib/test_bars.py
+++ b/plotly/tests/test_optional/test_matplotlylib/test_bars.py
@@ -2,15 +2,15 @@
from nose.plugins.attrib import attr
+from plotly import optional_imports
from plotly.tests.utils import compare_dict
from plotly.tests.test_optional.optional_utils import run_fig
from plotly.tests.test_optional.test_matplotlylib.data.bars import *
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
- import matplotlib
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+if matplotlylib:
+ import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
diff --git a/plotly/tests/test_optional/test_matplotlylib/test_data.py b/plotly/tests/test_optional/test_matplotlylib/test_data.py
index 33ad00543a3..b188dbc2d46 100644
--- a/plotly/tests/test_optional/test_matplotlylib/test_data.py
+++ b/plotly/tests/test_optional/test_matplotlylib/test_data.py
@@ -2,12 +2,13 @@
from nose.plugins.attrib import attr
+from plotly import optional_imports
from plotly.tests.test_optional.optional_utils import run_fig
from plotly.tests.test_optional.test_matplotlylib.data.data import *
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
diff --git a/plotly/tests/test_optional/test_matplotlylib/test_date_times.py b/plotly/tests/test_optional/test_matplotlylib/test_date_times.py
index 0c56efd5ad6..5f1bf59aefd 100644
--- a/plotly/tests/test_optional/test_matplotlylib/test_date_times.py
+++ b/plotly/tests/test_optional/test_matplotlylib/test_date_times.py
@@ -8,10 +8,11 @@
from nose.plugins.attrib import attr
import plotly.tools as tls
+from plotly import optional_imports
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
diff --git a/plotly/tests/test_optional/test_matplotlylib/test_lines.py b/plotly/tests/test_optional/test_matplotlylib/test_lines.py
index 9d1074bc571..b7355cccc22 100644
--- a/plotly/tests/test_optional/test_matplotlylib/test_lines.py
+++ b/plotly/tests/test_optional/test_matplotlylib/test_lines.py
@@ -2,13 +2,14 @@
from nose.plugins.attrib import attr
+from plotly import optional_imports
from plotly.tests.utils import compare_dict
from plotly.tests.test_optional.optional_utils import run_fig
from plotly.tests.test_optional.test_matplotlylib.data.lines import *
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
diff --git a/plotly/tests/test_optional/test_matplotlylib/test_scatter.py b/plotly/tests/test_optional/test_matplotlylib/test_scatter.py
index 07fc276f5fb..80774c079eb 100644
--- a/plotly/tests/test_optional/test_matplotlylib/test_scatter.py
+++ b/plotly/tests/test_optional/test_matplotlylib/test_scatter.py
@@ -2,13 +2,14 @@
from nose.plugins.attrib import attr
+from plotly import optional_imports
from plotly.tests.utils import compare_dict
from plotly.tests.test_optional.optional_utils import run_fig
from plotly.tests.test_optional.test_matplotlylib.data.scatter import *
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
diff --git a/plotly/tests/test_optional/test_matplotlylib/test_subplots.py b/plotly/tests/test_optional/test_matplotlylib/test_subplots.py
index 725ca8aa78b..c30ae23bbf9 100644
--- a/plotly/tests/test_optional/test_matplotlylib/test_subplots.py
+++ b/plotly/tests/test_optional/test_matplotlylib/test_subplots.py
@@ -2,13 +2,14 @@
from nose.plugins.attrib import attr
+from plotly import optional_imports
from plotly.tests.utils import compare_dict
from plotly.tests.test_optional.optional_utils import run_fig
from plotly.tests.test_optional.test_matplotlylib.data.subplots import *
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
diff --git a/plotly/tests/test_optional/test_offline/test_offline.py b/plotly/tests/test_optional/test_offline/test_offline.py
index 93d2c4c3770..17dd1af2bdd 100644
--- a/plotly/tests/test_optional/test_offline/test_offline.py
+++ b/plotly/tests/test_optional/test_offline/test_offline.py
@@ -6,16 +6,16 @@
from nose.tools import raises
from nose.plugins.attrib import attr
+from requests.compat import json as _json
from unittest import TestCase
-import json
import plotly
+from plotly import optional_imports
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
-if _matplotlylib_imported:
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
@@ -34,7 +34,6 @@ def test_iplot_doesnt_work_before_you_call_init_notebook_mode(self):
plotly.offline.iplot([{}])
def test_iplot_works_after_you_call_init_notebook_mode(self):
- plotly.tools._ipython_imported = True
plotly.offline.init_notebook_mode()
plotly.offline.iplot([{}])
@@ -47,7 +46,6 @@ def test_iplot_mpl_works_after_you_call_init_notebook_mode(self):
y = [100, 200, 300]
plt.plot(x, y, "o")
- plotly.tools._ipython_imported = True
plotly.offline.init_notebook_mode()
plotly.offline.iplot_mpl(fig)
@@ -75,8 +73,8 @@ def test_default_mpl_plot_generates_expected_html(self):
figure = plotly.tools.mpl_to_plotly(fig)
data = figure['data']
layout = figure['layout']
- data_json = json.dumps(data, cls=plotly.utils.PlotlyJSONEncoder)
- layout_json = json.dumps(layout, cls=plotly.utils.PlotlyJSONEncoder)
+ data_json = _json.dumps(data, cls=plotly.utils.PlotlyJSONEncoder)
+ layout_json = _json.dumps(layout, cls=plotly.utils.PlotlyJSONEncoder)
html = self._read_html(plotly.offline.plot_mpl(fig))
# just make sure a few of the parts are in here
diff --git a/plotly/tests/test_optional/test_plotly/test_plot_mpl.py b/plotly/tests/test_optional/test_plotly/test_plot_mpl.py
index e509d59d82a..876f7f4b9b3 100644
--- a/plotly/tests/test_optional/test_plotly/test_plot_mpl.py
+++ b/plotly/tests/test_optional/test_plotly/test_plot_mpl.py
@@ -10,13 +10,13 @@
from nose.plugins.attrib import attr
from nose.tools import raises
-from plotly import exceptions
+from plotly import exceptions, optional_imports
from plotly.plotly import plotly as py
from unittest import TestCase
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib
# Force matplotlib to not use any Xwindows backend.
diff --git a/plotly/tests/test_optional/test_tools/__init__.py b/plotly/tests/test_optional/test_tools/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/plotly/tests/test_core/test_tools/test_figure_factory.py b/plotly/tests/test_optional/test_tools/test_figure_factory.py
similarity index 99%
rename from plotly/tests/test_core/test_tools/test_figure_factory.py
rename to plotly/tests/test_optional/test_tools/test_figure_factory.py
index 9ebc68c55f7..170920be197 100644
--- a/plotly/tests/test_core/test_tools/test_figure_factory.py
+++ b/plotly/tests/test_optional/test_tools/test_figure_factory.py
@@ -1649,8 +1649,6 @@ def test_2D_density_all_args(self):
# def test_scipy_import_error(self):
-# # make sure Import Error is raised when _scipy_imported = False
-
# hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
# 3.5, 4.1, 4.4, 4.5, 4.5,
# 5.0, 5.0, 5.2, 5.5, 5.5,
diff --git a/plotly/tests/test_optional/test_utils/test_utils.py b/plotly/tests/test_optional/test_utils/test_utils.py
index a72f89a8b60..9531e12ad88 100644
--- a/plotly/tests/test_optional/test_utils/test_utils.py
+++ b/plotly/tests/test_optional/test_utils/test_utils.py
@@ -5,7 +5,6 @@
from __future__ import absolute_import
import datetime
-import json
import math
import decimal
from datetime import datetime as dt
@@ -16,14 +15,15 @@
import pytz
from nose.plugins.attrib import attr
from pandas.util.testing import assert_series_equal
+from requests.compat import json as _json
-from plotly import utils
+from plotly import optional_imports, utils
from plotly.graph_objs import Scatter, Scatter3d, Figure, Data
from plotly.grid_objs import Column
-# TODO: matplotlib-build-wip
-from plotly.tools import _matplotlylib_imported
-if _matplotlylib_imported:
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+
+if matplotlylib:
import matplotlib.pyplot as plt
from plotly.matplotlylib import Exporter, PlotlyRenderer
@@ -179,7 +179,7 @@ def test_column_json_encoding():
Column(mixed_list, 'col 2'),
Column(np_list, 'col 3')
]
- json_columns = json.dumps(
+ json_columns = _json.dumps(
columns, cls=utils.PlotlyJSONEncoder, sort_keys=True
)
assert('[{"data": [1, 2, 3], "name": "col 1"}, '
@@ -198,8 +198,8 @@ def test_figure_json_encoding():
data = Data([s1, s2])
figure = Figure(data=data)
- js1 = json.dumps(s1, cls=utils.PlotlyJSONEncoder, sort_keys=True)
- js2 = json.dumps(s2, cls=utils.PlotlyJSONEncoder, sort_keys=True)
+ js1 = _json.dumps(s1, cls=utils.PlotlyJSONEncoder, sort_keys=True)
+ js2 = _json.dumps(s2, cls=utils.PlotlyJSONEncoder, sort_keys=True)
assert(js1 == '{"type": "scatter3d", "x": [1, 2, 3], '
'"y": [1, 2, 3, null, null, null, "2014-01-05"], '
@@ -208,8 +208,8 @@ def test_figure_json_encoding():
assert(js2 == '{"type": "scatter", "x": [1, 2, 3]}')
# Test JSON encoding works
- json.dumps(data, cls=utils.PlotlyJSONEncoder, sort_keys=True)
- json.dumps(figure, cls=utils.PlotlyJSONEncoder, sort_keys=True)
+ _json.dumps(data, cls=utils.PlotlyJSONEncoder, sort_keys=True)
+ _json.dumps(figure, cls=utils.PlotlyJSONEncoder, sort_keys=True)
# Test data wasn't mutated
assert(bool(np.asarray(np_list ==
@@ -221,18 +221,18 @@ def test_figure_json_encoding():
def test_datetime_json_encoding():
- j1 = json.dumps(dt_list, cls=utils.PlotlyJSONEncoder)
+ j1 = _json.dumps(dt_list, cls=utils.PlotlyJSONEncoder)
assert(j1 == '["2014-01-05", '
'"2014-01-05 01:01:01", '
'"2014-01-05 01:01:01.000001"]')
- j2 = json.dumps({"x": dt_list}, cls=utils.PlotlyJSONEncoder)
+ j2 = _json.dumps({"x": dt_list}, cls=utils.PlotlyJSONEncoder)
assert(j2 == '{"x": ["2014-01-05", '
'"2014-01-05 01:01:01", '
'"2014-01-05 01:01:01.000001"]}')
def test_pandas_json_encoding():
- j1 = json.dumps(df['col 1'], cls=utils.PlotlyJSONEncoder)
+ j1 = _json.dumps(df['col 1'], cls=utils.PlotlyJSONEncoder)
assert(j1 == '[1, 2, 3, "2014-01-05", null, null, null]')
# Test that data wasn't mutated
@@ -240,28 +240,28 @@ def test_pandas_json_encoding():
pd.Series([1, 2, 3, dt(2014, 1, 5),
pd.NaT, np.NaN, np.Inf], name='col 1'))
- j2 = json.dumps(df.index, cls=utils.PlotlyJSONEncoder)
+ j2 = _json.dumps(df.index, cls=utils.PlotlyJSONEncoder)
assert(j2 == '[0, 1, 2, 3, 4, 5, 6]')
nat = [pd.NaT]
- j3 = json.dumps(nat, cls=utils.PlotlyJSONEncoder)
+ j3 = _json.dumps(nat, cls=utils.PlotlyJSONEncoder)
assert(j3 == '[null]')
assert(nat[0] is pd.NaT)
- j4 = json.dumps(rng, cls=utils.PlotlyJSONEncoder)
+ j4 = _json.dumps(rng, cls=utils.PlotlyJSONEncoder)
assert(j4 == '["2011-01-01", "2011-01-01 01:00:00"]')
- j5 = json.dumps(ts, cls=utils.PlotlyJSONEncoder)
+ j5 = _json.dumps(ts, cls=utils.PlotlyJSONEncoder)
assert(j5 == '[1.5, 2.5]')
assert_series_equal(ts, pd.Series([1.5, 2.5], index=rng))
- j6 = json.dumps(ts.index, cls=utils.PlotlyJSONEncoder)
+ j6 = _json.dumps(ts.index, cls=utils.PlotlyJSONEncoder)
assert(j6 == '["2011-01-01", "2011-01-01 01:00:00"]')
def test_numpy_masked_json_encoding():
l = [1, 2, np.ma.core.masked]
- j1 = json.dumps(l, cls=utils.PlotlyJSONEncoder)
+ j1 = _json.dumps(l, cls=utils.PlotlyJSONEncoder)
print(j1)
assert(j1 == '[1, 2, null]')
@@ -285,18 +285,18 @@ def test_masked_constants_example():
renderer = PlotlyRenderer()
Exporter(renderer).run(fig)
- json.dumps(renderer.plotly_fig, cls=utils.PlotlyJSONEncoder)
+ _json.dumps(renderer.plotly_fig, cls=utils.PlotlyJSONEncoder)
- jy = json.dumps(renderer.plotly_fig['data'][1]['y'],
+ jy = _json.dumps(renderer.plotly_fig['data'][1]['y'],
cls=utils.PlotlyJSONEncoder)
print(jy)
- array = json.loads(jy)
+ array = _json.loads(jy)
assert(array == [-398.11793027, -398.11792966, -398.11786308, None])
def test_numpy_dates():
a = np.arange(np.datetime64('2011-07-11'), np.datetime64('2011-07-18'))
- j1 = json.dumps(a, cls=utils.PlotlyJSONEncoder)
+ j1 = _json.dumps(a, cls=utils.PlotlyJSONEncoder)
assert(j1 == '["2011-07-11", "2011-07-12", "2011-07-13", '
'"2011-07-14", "2011-07-15", "2011-07-16", '
'"2011-07-17"]')
@@ -304,5 +304,5 @@ def test_numpy_dates():
def test_datetime_dot_date():
a = [datetime.date(2014, 1, 1), datetime.date(2014, 1, 2)]
- j1 = json.dumps(a, cls=utils.PlotlyJSONEncoder)
+ j1 = _json.dumps(a, cls=utils.PlotlyJSONEncoder)
assert(j1 == '["2014-01-01", "2014-01-02"]')
diff --git a/plotly/tests/utils.py b/plotly/tests/utils.py
index f8b1438e6f1..2d1113d68b4 100644
--- a/plotly/tests/utils.py
+++ b/plotly/tests/utils.py
@@ -1,5 +1,4 @@
import copy
-import json
from numbers import Number as Num
from unittest import TestCase
diff --git a/plotly/tools.py b/plotly/tools.py
index e43106a940d..c45d4cdabcd 100644
--- a/plotly/tools.py
+++ b/plotly/tools.py
@@ -8,19 +8,12 @@
"""
from __future__ import absolute_import
-from collections import OrderedDict
import warnings
import six
-import math
-import decimal
-
-from plotly import colors
-from plotly import utils
-from plotly import exceptions
-from plotly import graph_reference
-from plotly import session
+
+from plotly import exceptions, optional_imports, session, utils
from plotly.files import (CONFIG_FILE, CREDENTIALS_FILE, FILE_CONTENT,
check_file_permissions)
@@ -63,55 +56,9 @@ def warning_on_one_line(message, category, filename, lineno,
message)
warnings.formatwarning = warning_on_one_line
-try:
- from . import matplotlylib
- _matplotlylib_imported = True
-except ImportError:
- _matplotlylib_imported = False
-
-try:
- import IPython
- import IPython.core.display
- _ipython_imported = True
-except ImportError:
- _ipython_imported = False
-
-try:
- import numpy as np
- _numpy_imported = True
-except ImportError:
- _numpy_imported = False
-
-try:
- import pandas as pd
- _pandas_imported = True
-except ImportError:
- _pandas_imported = False
-
-try:
- import scipy as scp
- _scipy_imported = True
-except ImportError:
- _scipy_imported = False
-
-try:
- import scipy.spatial as scs
- _scipy__spatial_imported = True
-except ImportError:
- _scipy__spatial_imported = False
-
-try:
- import scipy.cluster.hierarchy as sch
- _scipy__cluster__hierarchy_imported = True
-except ImportError:
- _scipy__cluster__hierarchy_imported = False
-
-try:
- import scipy
- import scipy.stats
- _scipy_imported = True
-except ImportError:
- _scipy_imported = False
+ipython_core_display = optional_imports.get_module('IPython.core.display')
+matplotlylib = optional_imports.get_module('plotly.matplotlylib')
+sage_salvus = optional_imports.get_module('sage_salvus')
def get_config_defaults():
@@ -419,11 +366,11 @@ def embed(file_owner_or_url, file_id=None, width="100%", height=525):
height=height)
# see if we are in the SageMath Cloud
- from sage_salvus import html
- return html(s, hide=False)
+ if sage_salvus:
+ return sage_salvus.html(s, hide=False)
except:
pass
- if _ipython_imported:
+ if ipython_core_display:
if file_id:
plotly_domain = (
session.get_session_config().get('plotly_domain') or
@@ -502,7 +449,7 @@ def mpl_to_plotly(fig, resize=False, strip_style=False, verbose=False):
{plotly_domain}/python/getting-started
"""
- if _matplotlylib_imported:
+ if matplotlylib:
renderer = matplotlylib.PlotlyRenderer()
matplotlylib.Exporter(renderer).run(fig)
if resize:
@@ -1357,6 +1304,7 @@ def validate(obj, obj_type):
"""
# TODO: Deprecate or move. #283
+ from plotly import graph_reference
from plotly.graph_objs import graph_objs
if obj_type not in graph_reference.CLASSES:
@@ -1397,8 +1345,8 @@ def _replace_newline(obj):
return obj # we return the actual reference... but DON'T mutate.
-if _ipython_imported:
- class PlotlyDisplay(IPython.core.display.HTML):
+if ipython_core_display:
+ class PlotlyDisplay(ipython_core_display.HTML):
"""An IPython display object for use with plotly urls
PlotlyDisplay objects should be instantiated with a url for a plot.
@@ -1459,6137 +1407,91 @@ def return_figure_from_figure_or_data(figure_or_data, validate_figure):
class FigureFactory(object):
- """
- BETA functions to create specific chart types.
-
- This is beta as in: subject to change in a backwards incompatible way
- without notice.
-
- Supported chart types include candlestick, open high low close, quiver,
- streamline, distplot, dendrogram, annotated heatmap, and tables. See
- FigureFactory.create_candlestick, FigureFactory.create_ohlc,
- FigureFactory.create_quiver, FigureFactory.create_streamline,
- FigureFactory.create_distplot, FigureFactory.create_dendrogram,
- FigureFactory.create_annotated_heatmap, or FigureFactory.create_table for
- more information and examples of a specific chart type.
- """
-
- @staticmethod
- def _make_colorscale(colors, scale=None):
- """
- Makes a colorscale from a list of colors and scale
-
- Takes a list of colors and scales and constructs a colorscale based
- on the colors in sequential order. If 'scale' is left empty, a linear-
- interpolated colorscale will be generated. If 'scale' is a specificed
- list, it must be the same legnth as colors and must contain all floats
- For documentation regarding to the form of the output, see
- https://plot.ly/python/reference/#mesh3d-colorscale
- """
- colorscale = []
-
- if not scale:
- for j, color in enumerate(colors):
- colorscale.append([j * 1./(len(colors) - 1), color])
- return colorscale
-
- else:
- colorscale = [list(tup) for tup in zip(scale, colors)]
- return colorscale
-
- @staticmethod
- def _convert_colorscale_to_rgb(colorscale):
- """
- Converts the colors in a colorscale to rgb colors
-
- A colorscale is an array of arrays, each with a numeric value as the
- first item and a color as the second. This function specifically is
- converting a colorscale with tuple colors (each coordinate between 0
- and 1) into a colorscale with the colors transformed into rgb colors
- """
- for color in colorscale:
- color[1] = FigureFactory._convert_to_RGB_255(
- color[1]
- )
-
- for color in colorscale:
- color[1] = FigureFactory._label_rgb(
- color[1]
- )
- return colorscale
-
- @staticmethod
- def _make_linear_colorscale(colors):
- """
- Makes a list of colors into a colorscale-acceptable form
-
- For documentation regarding to the form of the output, see
- https://plot.ly/python/reference/#mesh3d-colorscale
- """
- scale = 1./(len(colors) - 1)
- return[[i * scale, color] for i, color in enumerate(colors)]
-
- @staticmethod
- def create_2D_density(x, y, colorscale='Earth', ncontours=20,
- hist_color=(0, 0, 0.5), point_color=(0, 0, 0.5),
- point_size=2, title='2D Density Plot',
- height=600, width=600):
- """
- Returns figure for a 2D density plot
-
- :param (list|array) x: x-axis data for plot generation
- :param (list|array) y: y-axis data for plot generation
- :param (str|tuple|list) colorscale: either a plotly scale name, an rgb
- or hex color, a color tuple or a list or tuple of colors. An rgb
- color is of the form 'rgb(x, y, z)' where x, y, z belong to the
- interval [0, 255] and a color tuple is a tuple of the form
- (a, b, c) where a, b and c belong to [0, 1]. If colormap is a
- list, it must contain the valid color types aforementioned as its
- members.
- :param (int) ncontours: the number of 2D contours to draw on the plot
- :param (str) hist_color: the color of the plotted histograms
- :param (str) point_color: the color of the scatter points
- :param (str) point_size: the color of the scatter points
- :param (str) title: set the title for the plot
- :param (float) height: the height of the chart
- :param (float) width: the width of the chart
-
- Example 1: Simple 2D Density Plot
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
-
- # Make data points
- t = np.linspace(-1,1.2,2000)
- x = (t**3)+(0.3*np.random.randn(2000))
- y = (t**6)+(0.3*np.random.randn(2000))
-
- # Create a figure
- fig = FF.create_2D_density(x, y)
-
- # Plot the data
- py.iplot(fig, filename='simple-2d-density')
- ```
-
- Example 2: Using Parameters
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
-
- # Make data points
- t = np.linspace(-1,1.2,2000)
- x = (t**3)+(0.3*np.random.randn(2000))
- y = (t**6)+(0.3*np.random.randn(2000))
-
- # Create custom colorscale
- colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
- (1, 1, 0.2), (0.98,0.98,0.98)]
-
- # Create a figure
- fig = FF.create_2D_density(
- x, y, colorscale=colorscale,
- hist_color='rgb(255, 237, 222)', point_size=3)
-
- # Plot the data
- py.iplot(fig, filename='use-parameters')
- ```
- """
- from plotly.graph_objs import graph_objs
- from numbers import Number
-
- # validate x and y are filled with numbers only
- for array in [x, y]:
- if not all(isinstance(element, Number) for element in array):
- raise exceptions.PlotlyError(
- "All elements of your 'x' and 'y' lists must be numbers."
- )
-
- # validate x and y are the same length
- if len(x) != len(y):
- raise exceptions.PlotlyError(
- "Both lists 'x' and 'y' must be the same length."
- )
-
- colorscale = FigureFactory._validate_colors(colorscale, 'rgb')
- colorscale = FigureFactory._make_linear_colorscale(colorscale)
-
- # validate hist_color and point_color
- hist_color = FigureFactory._validate_colors(hist_color, 'rgb')
- point_color = FigureFactory._validate_colors(point_color, 'rgb')
-
- trace1 = graph_objs.Scatter(
- x=x, y=y, mode='markers', name='points',
- marker=dict(
- color=point_color[0],
- size=point_size,
- opacity=0.4
- )
- )
- trace2 = graph_objs.Histogram2dcontour(
- x=x, y=y, name='density', ncontours=ncontours,
- colorscale=colorscale, reversescale=True, showscale=False
- )
- trace3 = graph_objs.Histogram(
- x=x, name='x density',
- marker=dict(color=hist_color[0]), yaxis='y2'
- )
- trace4 = graph_objs.Histogram(
- y=y, name='y density',
- marker=dict(color=hist_color[0]), xaxis='x2'
- )
- data = [trace1, trace2, trace3, trace4]
-
- layout = graph_objs.Layout(
- showlegend=False,
- autosize=False,
- title=title,
- height=height,
- width=width,
- xaxis=dict(
- domain=[0, 0.85],
- showgrid=False,
- zeroline=False
- ),
- yaxis=dict(
- domain=[0, 0.85],
- showgrid=False,
- zeroline=False
- ),
- margin=dict(
- t=50
- ),
- hovermode='closest',
- bargap=0,
- xaxis2=dict(
- domain=[0.85, 1],
- showgrid=False,
- zeroline=False
- ),
- yaxis2=dict(
- domain=[0.85, 1],
- showgrid=False,
- zeroline=False
- )
- )
-
- fig = graph_objs.Figure(data=data, layout=layout)
- return fig
-
- @staticmethod
- def _validate_gantt(df):
- """
- Validates the inputted dataframe or list
- """
- if _pandas_imported and isinstance(df, pd.core.frame.DataFrame):
- # validate that df has all the required keys
- for key in REQUIRED_GANTT_KEYS:
- if key not in df:
- raise exceptions.PlotlyError(
- "The columns in your dataframe must include the "
- "following keys: {0}".format(', '.join(REQUIRED_GANTT_KEYS))
- )
-
- num_of_rows = len(df.index)
- chart = []
- for index in range(num_of_rows):
- task_dict = {}
- for key in df:
- task_dict[key] = df.ix[index][key]
- chart.append(task_dict)
-
- return chart
-
- # validate if df is a list
- if not isinstance(df, list):
- raise exceptions.PlotlyError("You must input either a dataframe "
- "or a list of dictionaries.")
-
- # validate if df is empty
- if len(df) <= 0:
- raise exceptions.PlotlyError("Your list is empty. It must contain "
- "at least one dictionary.")
- if not isinstance(df[0], dict):
- raise exceptions.PlotlyError("Your list must only "
- "include dictionaries.")
- return df
-
- @staticmethod
- def _gantt(chart, colors, title, bar_width, showgrid_x, showgrid_y,
- height, width, tasks=None, task_names=None, data=None, group_tasks=False):
- """
- Refer to FigureFactory.create_gantt() for docstring
- """
- if tasks is None:
- tasks = []
- if task_names is None:
- task_names = []
- if data is None:
- data = []
-
- for index in range(len(chart)):
- task = dict(x0=chart[index]['Start'],
- x1=chart[index]['Finish'],
- name=chart[index]['Task'])
- if 'Description' in chart[index]:
- task['description'] = chart[index]['Description']
- tasks.append(task)
-
- shape_template = {
- 'type': 'rect',
- 'xref': 'x',
- 'yref': 'y',
- 'opacity': 1,
- 'line': {
- 'width': 0,
- },
- 'yref': 'y',
- }
- # create the list of task names
- for index in range(len(tasks)):
- tn = tasks[index]['name']
- # Is added to task_names if group_tasks is set to False,
- # or if the option is used (True) it only adds them if the
- # name is not already in the list
- if not group_tasks or tn not in task_names:
- task_names.append(tn)
- # Guarantees that for grouped tasks the tasks that are inserted first
- # are shown at the top
- if group_tasks:
- task_names.reverse()
-
-
- color_index = 0
- for index in range(len(tasks)):
- tn = tasks[index]['name']
- del tasks[index]['name']
- tasks[index].update(shape_template)
-
- # If group_tasks is True, all tasks with the same name belong
- # to the same row.
- groupID = index
- if group_tasks:
- groupID = task_names.index(tn)
- tasks[index]['y0'] = groupID - bar_width
- tasks[index]['y1'] = groupID + bar_width
-
- # check if colors need to be looped
- if color_index >= len(colors):
- color_index = 0
- tasks[index]['fillcolor'] = colors[color_index]
- # Add a line for hover text and autorange
- entry = dict(
- x=[tasks[index]['x0'], tasks[index]['x1']],
- y=[groupID, groupID],
- name='',
- marker={'color': 'white'}
- )
- if "description" in tasks[index]:
- entry['text'] = tasks[index]['description']
- del tasks[index]['description']
- data.append(entry)
- color_index += 1
-
- layout = dict(
- title=title,
- showlegend=False,
- height=height,
- width=width,
- shapes=[],
- hovermode='closest',
- yaxis=dict(
- showgrid=showgrid_y,
- ticktext=task_names,
- tickvals=list(range(len(task_names))),
- range=[-1, len(task_names) + 1],
- autorange=False,
- zeroline=False,
- ),
- xaxis=dict(
- showgrid=showgrid_x,
- zeroline=False,
- rangeselector=dict(
- buttons=list([
- dict(count=7,
- label='1w',
- step='day',
- stepmode='backward'),
- dict(count=1,
- label='1m',
- step='month',
- stepmode='backward'),
- dict(count=6,
- label='6m',
- step='month',
- stepmode='backward'),
- dict(count=1,
- label='YTD',
- step='year',
- stepmode='todate'),
- dict(count=1,
- label='1y',
- step='year',
- stepmode='backward'),
- dict(step='all')
- ])
- ),
- type='date'
- )
- )
- layout['shapes'] = tasks
-
- fig = dict(data=data, layout=layout)
- return fig
-
- @staticmethod
- def _gantt_colorscale(chart, colors, title, index_col, show_colorbar,
- bar_width, showgrid_x, showgrid_y, height,
- width, tasks=None, task_names=None, data=None, group_tasks=False):
- """
- Refer to FigureFactory.create_gantt() for docstring
- """
- from numbers import Number
- if tasks is None:
- tasks = []
- if task_names is None:
- task_names = []
- if data is None:
- data = []
- showlegend = False
-
- for index in range(len(chart)):
- task = dict(x0=chart[index]['Start'],
- x1=chart[index]['Finish'],
- name=chart[index]['Task'])
- if 'Description' in chart[index]:
- task['description'] = chart[index]['Description']
- tasks.append(task)
-
- shape_template = {
- 'type': 'rect',
- 'xref': 'x',
- 'yref': 'y',
- 'opacity': 1,
- 'line': {
- 'width': 0,
- },
- 'yref': 'y',
- }
-
- # compute the color for task based on indexing column
- if isinstance(chart[0][index_col], Number):
- # check that colors has at least 2 colors
- if len(colors) < 2:
- raise exceptions.PlotlyError(
- "You must use at least 2 colors in 'colors' if you "
- "are using a colorscale. However only the first two "
- "colors given will be used for the lower and upper "
- "bounds on the colormap."
- )
-
- # create the list of task names
- for index in range(len(tasks)):
- tn = tasks[index]['name']
- # Is added to task_names if group_tasks is set to False,
- # or if the option is used (True) it only adds them if the
- # name is not already in the list
- if not group_tasks or tn not in task_names:
- task_names.append(tn)
- # Guarantees that for grouped tasks the tasks that are inserted
- # first are shown at the top
- if group_tasks:
- task_names.reverse()
-
- for index in range(len(tasks)):
- tn = tasks[index]['name']
- del tasks[index]['name']
- tasks[index].update(shape_template)
-
- # If group_tasks is True, all tasks with the same name belong
- # to the same row.
- groupID = index
- if group_tasks:
- groupID = task_names.index(tn)
- tasks[index]['y0'] = groupID - bar_width
- tasks[index]['y1'] = groupID + bar_width
-
- # unlabel color
- colors = FigureFactory._color_parser(
- colors, FigureFactory._unlabel_rgb
- )
- lowcolor = colors[0]
- highcolor = colors[1]
-
- intermed = (chart[index][index_col])/100.0
- intermed_color = FigureFactory._find_intermediate_color(
- lowcolor, highcolor, intermed
- )
- intermed_color = FigureFactory._color_parser(
- intermed_color, FigureFactory._label_rgb
- )
- tasks[index]['fillcolor'] = intermed_color
- # relabel colors with 'rgb'
- colors = FigureFactory._color_parser(
- colors, FigureFactory._label_rgb
- )
-
- # add a line for hover text and autorange
- entry = dict(
- x=[tasks[index]['x0'], tasks[index]['x1']],
- y=[groupID, groupID],
- name='',
- marker={'color': 'white'}
- )
- if "description" in tasks[index]:
- entry['text'] = tasks[index]['description']
- del tasks[index]['description']
- data.append(entry)
-
-
- if show_colorbar is True:
- # generate dummy data for colorscale visibility
- data.append(
- dict(
- x=[tasks[index]['x0'], tasks[index]['x0']],
- y=[index, index],
- name='',
- marker={'color': 'white',
- 'colorscale': [[0, colors[0]], [1, colors[1]]],
- 'showscale': True,
- 'cmax': 100,
- 'cmin': 0}
- )
- )
-
- if isinstance(chart[0][index_col], str):
- index_vals = []
- for row in range(len(tasks)):
- if chart[row][index_col] not in index_vals:
- index_vals.append(chart[row][index_col])
-
- index_vals.sort()
-
- if len(colors) < len(index_vals):
- raise exceptions.PlotlyError(
- "Error. The number of colors in 'colors' must be no less "
- "than the number of unique index values in your group "
- "column."
- )
-
- # make a dictionary assignment to each index value
- index_vals_dict = {}
- # define color index
- c_index = 0
- for key in index_vals:
- if c_index > len(colors) - 1:
- c_index = 0
- index_vals_dict[key] = colors[c_index]
- c_index += 1
-
- # create the list of task names
- for index in range(len(tasks)):
- tn = tasks[index]['name']
- # Is added to task_names if group_tasks is set to False,
- # or if the option is used (True) it only adds them if the
- # name is not already in the list
- if not group_tasks or tn not in task_names:
- task_names.append(tn)
- # Guarantees that for grouped tasks the tasks that are inserted
- # first are shown at the top
- if group_tasks:
- task_names.reverse()
-
- for index in range(len(tasks)):
- tn = tasks[index]['name']
- del tasks[index]['name']
- tasks[index].update(shape_template)
- # If group_tasks is True, all tasks with the same name belong
- # to the same row.
- groupID = index
- if group_tasks:
- groupID = task_names.index(tn)
- tasks[index]['y0'] = groupID - bar_width
- tasks[index]['y1'] = groupID + bar_width
-
- tasks[index]['fillcolor'] = index_vals_dict[
- chart[index][index_col]
- ]
-
- # add a line for hover text and autorange
- entry = dict(
- x=[tasks[index]['x0'], tasks[index]['x1']],
- y=[groupID, groupID],
- name='',
- marker={'color': 'white'}
- )
- if "description" in tasks[index]:
- entry['text'] = tasks[index]['description']
- del tasks[index]['description']
- data.append(entry)
-
- if show_colorbar is True:
- # generate dummy data to generate legend
- showlegend = True
- for k, index_value in enumerate(index_vals):
- data.append(
- dict(
- x=[tasks[index]['x0'], tasks[index]['x0']],
- y=[k, k],
- showlegend=True,
- name=str(index_value),
- hoverinfo='none',
- marker=dict(
- color=colors[k],
- size=1
- )
- )
- )
-
- layout = dict(
- title=title,
- showlegend=showlegend,
- height=height,
- width=width,
- shapes=[],
- hovermode='closest',
- yaxis=dict(
- showgrid=showgrid_y,
- ticktext=task_names,
- tickvals=list(range(len(task_names))),
- range=[-1, len(task_names) + 1],
- autorange=False,
- zeroline=False,
- ),
- xaxis=dict(
- showgrid=showgrid_x,
- zeroline=False,
- rangeselector=dict(
- buttons=list([
- dict(count=7,
- label='1w',
- step='day',
- stepmode='backward'),
- dict(count=1,
- label='1m',
- step='month',
- stepmode='backward'),
- dict(count=6,
- label='6m',
- step='month',
- stepmode='backward'),
- dict(count=1,
- label='YTD',
- step='year',
- stepmode='todate'),
- dict(count=1,
- label='1y',
- step='year',
- stepmode='backward'),
- dict(step='all')
- ])
- ),
- type='date'
- )
- )
- layout['shapes'] = tasks
-
- fig = dict(data=data, layout=layout)
- return fig
-
- @staticmethod
- def _gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width,
- showgrid_x, showgrid_y, height, width, tasks=None,
- task_names=None, data=None, group_tasks=False):
- """
- Refer to FigureFactory.create_gantt() for docstring
- """
- if tasks is None:
- tasks = []
- if task_names is None:
- task_names = []
- if data is None:
- data = []
- showlegend = False
-
- for index in range(len(chart)):
- task = dict(x0=chart[index]['Start'],
- x1=chart[index]['Finish'],
- name=chart[index]['Task'])
- if 'Description' in chart[index]:
- task['description'] = chart[index]['Description']
- tasks.append(task)
-
- shape_template = {
- 'type': 'rect',
- 'xref': 'x',
- 'yref': 'y',
- 'opacity': 1,
- 'line': {
- 'width': 0,
- },
- 'yref': 'y',
- }
-
- index_vals = []
- for row in range(len(tasks)):
- if chart[row][index_col] not in index_vals:
- index_vals.append(chart[row][index_col])
-
- index_vals.sort()
-
- # verify each value in index column appears in colors dictionary
- for key in index_vals:
- if key not in colors:
- raise exceptions.PlotlyError(
- "If you are using colors as a dictionary, all of its "
- "keys must be all the values in the index column."
- )
-
- # create the list of task names
- for index in range(len(tasks)):
- tn = tasks[index]['name']
- # Is added to task_names if group_tasks is set to False,
- # or if the option is used (True) it only adds them if the
- # name is not already in the list
- if not group_tasks or tn not in task_names:
- task_names.append(tn)
- # Guarantees that for grouped tasks the tasks that are inserted first
- # are shown at the top
- if group_tasks:
- task_names.reverse()
-
- for index in range(len(tasks)):
- tn = tasks[index]['name']
- del tasks[index]['name']
- tasks[index].update(shape_template)
-
- # If group_tasks is True, all tasks with the same name belong
- # to the same row.
- groupID = index
- if group_tasks:
- groupID = task_names.index(tn)
- tasks[index]['y0'] = groupID - bar_width
- tasks[index]['y1'] = groupID + bar_width
-
- tasks[index]['fillcolor'] = colors[chart[index][index_col]]
-
- # add a line for hover text and autorange
- entry = dict(
- x=[tasks[index]['x0'], tasks[index]['x1']],
- y=[groupID, groupID],
- name='',
- marker={'color': 'white'}
- )
- if "description" in tasks[index]:
- entry['text'] = tasks[index]['description']
- del tasks[index]['description']
- data.append(entry)
-
- if show_colorbar is True:
- # generate dummy data to generate legend
- showlegend = True
- for k, index_value in enumerate(index_vals):
- data.append(
- dict(
- x=[tasks[index]['x0'], tasks[index]['x0']],
- y=[k, k],
- showlegend=True,
- hoverinfo='none',
- name=str(index_value),
- marker=dict(
- color=colors[index_value],
- size=1
- )
- )
- )
-
- layout = dict(
- title=title,
- showlegend=showlegend,
- height=height,
- width=width,
- shapes=[],
- hovermode='closest',
- yaxis=dict(
- showgrid=showgrid_y,
- ticktext=task_names,
- tickvals=list(range(len(task_names))),
- range=[-1, len(task_names) + 1],
- autorange=False,
- zeroline=False,
- ),
- xaxis=dict(
- showgrid=showgrid_x,
- zeroline=False,
- rangeselector=dict(
- buttons=list([
- dict(count=7,
- label='1w',
- step='day',
- stepmode='backward'),
- dict(count=1,
- label='1m',
- step='month',
- stepmode='backward'),
- dict(count=6,
- label='6m',
- step='month',
- stepmode='backward'),
- dict(count=1,
- label='YTD',
- step='year',
- stepmode='todate'),
- dict(count=1,
- label='1y',
- step='year',
- stepmode='backward'),
- dict(step='all')
- ])
- ),
- type='date'
- )
- )
- layout['shapes'] = tasks
-
- fig = dict(data=data, layout=layout)
- return fig
-
- @staticmethod
- def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
- reverse_colors=False, title='Gantt Chart',
- bar_width=0.2, showgrid_x=False, showgrid_y=False,
- height=600, width=900, tasks=None,
- task_names=None, data=None, group_tasks=False):
- """
- Returns figure for a gantt chart
-
- :param (array|list) df: input data for gantt chart. Must be either a
- a dataframe or a list. If dataframe, the columns must include
- 'Task', 'Start' and 'Finish'. Other columns can be included and
- used for indexing. If a list, its elements must be dictionaries
- with the same required column headers: 'Task', 'Start' and
- 'Finish'.
- :param (str|list|dict|tuple) colors: either a plotly scale name, an
- rgb or hex color, a color tuple or a list of colors. An rgb color
- is of the form 'rgb(x, y, z)' where x, y, z belong to the interval
- [0, 255] and a color tuple is a tuple of the form (a, b, c) where
- a, b and c belong to [0, 1]. If colors is a list, it must
- contain the valid color types aforementioned as its members.
- If a dictionary, all values of the indexing column must be keys in
- colors.
- :param (str|float) index_col: the column header (if df is a data
- frame) that will function as the indexing column. If df is a list,
- index_col must be one of the keys in all the items of df.
- :param (bool) show_colorbar: determines if colorbar will be visible.
- Only applies if values in the index column are numeric.
- :param (bool) reverse_colors: reverses the order of selected colors
- :param (str) title: the title of the chart
- :param (float) bar_width: the width of the horizontal bars in the plot
- :param (bool) showgrid_x: show/hide the x-axis grid
- :param (bool) showgrid_y: show/hide the y-axis grid
- :param (float) height: the height of the chart
- :param (float) width: the width of the chart
-
- Example 1: Simple Gantt Chart
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- # Make data for chart
- df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'),
- dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'),
- dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')]
-
- # Create a figure
- fig = FF.create_gantt(df)
-
- # Plot the data
- py.iplot(fig, filename='Simple Gantt Chart', world_readable=True)
- ```
-
- Example 2: Index by Column with Numerical Entries
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- # Make data for chart
- df = [dict(Task="Job A", Start='2009-01-01',
- Finish='2009-02-30', Complete=10),
- dict(Task="Job B", Start='2009-03-05',
- Finish='2009-04-15', Complete=60),
- dict(Task="Job C", Start='2009-02-20',
- Finish='2009-05-30', Complete=95)]
-
- # Create a figure with Plotly colorscale
- fig = FF.create_gantt(df, colors='Blues', index_col='Complete',
- show_colorbar=True, bar_width=0.5,
- showgrid_x=True, showgrid_y=True)
-
- # Plot the data
- py.iplot(fig, filename='Numerical Entries', world_readable=True)
- ```
-
- Example 3: Index by Column with String Entries
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- # Make data for chart
- df = [dict(Task="Job A", Start='2009-01-01',
- Finish='2009-02-30', Resource='Apple'),
- dict(Task="Job B", Start='2009-03-05',
- Finish='2009-04-15', Resource='Grape'),
- dict(Task="Job C", Start='2009-02-20',
- Finish='2009-05-30', Resource='Banana')]
-
- # Create a figure with Plotly colorscale
- fig = FF.create_gantt(df, colors=['rgb(200, 50, 25)',
- (1, 0, 1),
- '#6c4774'],
- index_col='Resource',
- reverse_colors=True,
- show_colorbar=True)
-
- # Plot the data
- py.iplot(fig, filename='String Entries', world_readable=True)
- ```
-
- Example 4: Use a dictionary for colors
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- # Make data for chart
- df = [dict(Task="Job A", Start='2009-01-01',
- Finish='2009-02-30', Resource='Apple'),
- dict(Task="Job B", Start='2009-03-05',
- Finish='2009-04-15', Resource='Grape'),
- dict(Task="Job C", Start='2009-02-20',
- Finish='2009-05-30', Resource='Banana')]
-
- # Make a dictionary of colors
- colors = {'Apple': 'rgb(255, 0, 0)',
- 'Grape': 'rgb(170, 14, 200)',
- 'Banana': (1, 1, 0.2)}
-
- # Create a figure with Plotly colorscale
- fig = FF.create_gantt(df, colors=colors,
- index_col='Resource',
- show_colorbar=True)
-
- # Plot the data
- py.iplot(fig, filename='dictioanry colors', world_readable=True)
- ```
-
- Example 5: Use a pandas dataframe
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import pandas as pd
-
- # Make data as a dataframe
- df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10],
- ['Fast', '2011-01-01', '2012-06-05', 55],
- ['Eat', '2012-01-05', '2013-07-05', 94]],
- columns=['Task', 'Start', 'Finish', 'Complete'])
-
- # Create a figure with Plotly colorscale
- fig = FF.create_gantt(df, colors='Blues', index_col='Complete',
- show_colorbar=True, bar_width=0.5,
- showgrid_x=True, showgrid_y=True)
-
- # Plot the data
- py.iplot(fig, filename='data with dataframe', world_readable=True)
- ```
- """
- # validate gantt input data
- chart = FigureFactory._validate_gantt(df)
-
- if index_col:
- if index_col not in chart[0]:
- raise exceptions.PlotlyError(
- "In order to use an indexing column and assign colors to "
- "the values of the index, you must choose an actual "
- "column name in the dataframe or key if a list of "
- "dictionaries is being used.")
-
- # validate gantt index column
- index_list = []
- for dictionary in chart:
- index_list.append(dictionary[index_col])
- FigureFactory._validate_index(index_list)
-
- # Validate colors
- if isinstance(colors, dict):
- colors = FigureFactory._validate_colors_dict(colors, 'rgb')
- else:
- colors = FigureFactory._validate_colors(colors, 'rgb')
-
- if reverse_colors is True:
- colors.reverse()
-
- if not index_col:
- if isinstance(colors, dict):
- raise exceptions.PlotlyError(
- "Error. You have set colors to a dictionary but have not "
- "picked an index. An index is required if you are "
- "assigning colors to particular values in a dictioanry."
- )
- fig = FigureFactory._gantt(
- chart, colors, title, bar_width, showgrid_x, showgrid_y,
- height, width, tasks=None, task_names=None, data=None,
- group_tasks=group_tasks
- )
- return fig
- else:
- if not isinstance(colors, dict):
- fig = FigureFactory._gantt_colorscale(
- chart, colors, title, index_col, show_colorbar, bar_width,
- showgrid_x, showgrid_y, height, width,
- tasks=None, task_names=None, data=None, group_tasks=group_tasks
- )
- return fig
- else:
- fig = FigureFactory._gantt_dict(
- chart, colors, title, index_col, show_colorbar, bar_width,
- showgrid_x, showgrid_y, height, width,
- tasks=None, task_names=None, data=None, group_tasks=group_tasks
- )
- return fig
-
- @staticmethod
- def _validate_colors(colors, colortype='tuple'):
- """
- Validates color(s) and returns a list of color(s) of a specified type
- """
- from numbers import Number
- if colors is None:
- colors = DEFAULT_PLOTLY_COLORS
-
- if isinstance(colors, str):
- if colors in PLOTLY_SCALES:
- colors = PLOTLY_SCALES[colors]
- elif 'rgb' in colors or '#' in colors:
- colors = [colors]
- else:
- raise exceptions.PlotlyError(
- "If your colors variable is a string, it must be a "
- "Plotly scale, an rgb color or a hex color.")
-
- elif isinstance(colors, tuple):
- if isinstance(colors[0], Number):
- colors = [colors]
- else:
- colors = list(colors)
-
-
- # convert color elements in list to tuple color
- for j, each_color in enumerate(colors):
- if 'rgb' in each_color:
- each_color = FigureFactory._color_parser(
- each_color, FigureFactory._unlabel_rgb
- )
- for value in each_color:
- if value > 255.0:
- raise exceptions.PlotlyError(
- "Whoops! The elements in your rgb colors "
- "tuples cannot exceed 255.0."
- )
- each_color = FigureFactory._color_parser(
- each_color, FigureFactory._unconvert_from_RGB_255
- )
- colors[j] = each_color
-
- if '#' in each_color:
- each_color = FigureFactory._color_parser(
- each_color, FigureFactory._hex_to_rgb
- )
- each_color = FigureFactory._color_parser(
- each_color, FigureFactory._unconvert_from_RGB_255
- )
-
- colors[j] = each_color
-
- if isinstance(each_color, tuple):
- for value in each_color:
- if value > 1.0:
- raise exceptions.PlotlyError(
- "Whoops! The elements in your colors tuples "
- "cannot exceed 1.0."
- )
- colors[j] = each_color
-
- if colortype == 'rgb':
- for j, each_color in enumerate(colors):
- rgb_color = FigureFactory._color_parser(
- each_color, FigureFactory._convert_to_RGB_255
- )
- colors[j] = FigureFactory._color_parser(
- rgb_color, FigureFactory._label_rgb
- )
-
- return colors
-
- @staticmethod
- def _validate_colors_dict(colors, colortype='tuple'):
- """
- Validates dictioanry of color(s)
- """
- # validate each color element in the dictionary
- for key in colors:
- if 'rgb' in colors[key]:
- colors[key] = FigureFactory._color_parser(
- colors[key], FigureFactory._unlabel_rgb
- )
- for value in colors[key]:
- if value > 255.0:
- raise exceptions.PlotlyError(
- "Whoops! The elements in your rgb colors "
- "tuples cannot exceed 255.0."
- )
- colors[key] = FigureFactory._color_parser(
- colors[key], FigureFactory._unconvert_from_RGB_255
- )
-
- if '#' in colors[key]:
- colors[key] = FigureFactory._color_parser(
- colors[key], FigureFactory._hex_to_rgb
- )
- colors[key] = FigureFactory._color_parser(
- colors[key], FigureFactory._unconvert_from_RGB_255
- )
-
- if isinstance(colors[key], tuple):
- for value in colors[key]:
- if value > 1.0:
- raise exceptions.PlotlyError(
- "Whoops! The elements in your colors tuples "
- "cannot exceed 1.0."
- )
-
- if colortype == 'rgb':
- for key in colors:
- colors[key] = FigureFactory._color_parser(
- colors[key], FigureFactory._convert_to_RGB_255
- )
- colors[key] = FigureFactory._color_parser(
- colors[key], FigureFactory._label_rgb
- )
-
- return colors
-
- @staticmethod
- def _calc_stats(data):
- """
- Calculate statistics for use in violin plot.
- """
- import numpy as np
-
- x = np.asarray(data, np.float)
- vals_min = np.min(x)
- vals_max = np.max(x)
- q2 = np.percentile(x, 50, interpolation='linear')
- q1 = np.percentile(x, 25, interpolation='lower')
- q3 = np.percentile(x, 75, interpolation='higher')
- iqr = q3 - q1
- whisker_dist = 1.5 * iqr
-
- # in order to prevent drawing whiskers outside the interval
- # of data one defines the whisker positions as:
- d1 = np.min(x[x >= (q1 - whisker_dist)])
- d2 = np.max(x[x <= (q3 + whisker_dist)])
- return {
- 'min': vals_min,
- 'max': vals_max,
- 'q1': q1,
- 'q2': q2,
- 'q3': q3,
- 'd1': d1,
- 'd2': d2
- }
-
- @staticmethod
- def _make_half_violin(x, y, fillcolor='#1f77b4',
- linecolor='rgb(0, 0, 0)'):
- """
- Produces a sideways probability distribution fig violin plot.
- """
- from plotly.graph_objs import graph_objs
-
- text = ['(pdf(y), y)=(' + '{:0.2f}'.format(x[i]) +
- ', ' + '{:0.2f}'.format(y[i]) + ')'
- for i in range(len(x))]
-
- return graph_objs.Scatter(
- x=x,
- y=y,
- mode='lines',
- name='',
- text=text,
- fill='tonextx',
- fillcolor=fillcolor,
- line=graph_objs.Line(width=0.5, color=linecolor, shape='spline'),
- hoverinfo='text',
- opacity=0.5
- )
-
- @staticmethod
- def _make_violin_rugplot(vals, pdf_max, distance,
- color='#1f77b4'):
- """
- Returns a rugplot fig for a violin plot.
- """
- from plotly.graph_objs import graph_objs
-
- return graph_objs.Scatter(
- y=vals,
- x=[-pdf_max-distance]*len(vals),
- marker=graph_objs.Marker(
- color=color,
- symbol='line-ew-open'
- ),
- mode='markers',
- name='',
- showlegend=False,
- hoverinfo='y'
- )
-
- @staticmethod
- def _make_quartiles(q1, q3):
- """
- Makes the upper and lower quartiles for a violin plot.
- """
- from plotly.graph_objs import graph_objs
-
- return graph_objs.Scatter(
- x=[0, 0],
- y=[q1, q3],
- text=['lower-quartile: ' + '{:0.2f}'.format(q1),
- 'upper-quartile: ' + '{:0.2f}'.format(q3)],
- mode='lines',
- line=graph_objs.Line(
- width=4,
- color='rgb(0,0,0)'
- ),
- hoverinfo='text'
- )
-
- @staticmethod
- def _make_median(q2):
- """
- Formats the 'median' hovertext for a violin plot.
- """
- from plotly.graph_objs import graph_objs
-
- return graph_objs.Scatter(
- x=[0],
- y=[q2],
- text=['median: ' + '{:0.2f}'.format(q2)],
- mode='markers',
- marker=dict(symbol='square',
- color='rgb(255,255,255)'),
- hoverinfo='text'
- )
@staticmethod
- def _make_non_outlier_interval(d1, d2):
- """
- Returns the scatterplot fig of most of a violin plot.
- """
- from plotly.graph_objs import graph_objs
-
- return graph_objs.Scatter(
- x=[0, 0],
- y=[d1, d2],
- name='',
- mode='lines',
- line=graph_objs.Line(width=1.5,
- color='rgb(0,0,0)')
+ def _deprecated(old_method, new_method=None):
+ if new_method is None:
+ # The method name stayed the same.
+ new_method = old_method
+ warnings.warn(
+ 'plotly.tools.FigureFactory.{} is deprecated. '
+ 'Use plotly.figure_factory.{}'.format(old_method, new_method)
)
@staticmethod
- def _make_XAxis(xaxis_title, xaxis_range):
- """
- Makes the x-axis for a violin plot.
- """
- from plotly.graph_objs import graph_objs
-
- xaxis = graph_objs.XAxis(title=xaxis_title,
- range=xaxis_range,
- showgrid=False,
- zeroline=False,
- showline=False,
- mirror=False,
- ticks='',
- showticklabels=False,
- )
- return xaxis
-
- @staticmethod
- def _make_YAxis(yaxis_title):
- """
- Makes the y-axis for a violin plot.
- """
- from plotly.graph_objs import graph_objs
-
- yaxis = graph_objs.YAxis(title=yaxis_title,
- showticklabels=True,
- autorange=True,
- ticklen=4,
- showline=True,
- zeroline=False,
- showgrid=False,
- mirror=False)
- return yaxis
-
- @staticmethod
- def _violinplot(vals, fillcolor='#1f77b4', rugplot=True):
- """
- Refer to FigureFactory.create_violin() for docstring.
- """
- import numpy as np
- from scipy import stats
-
- vals = np.asarray(vals, np.float)
- # summary statistics
- vals_min = FigureFactory._calc_stats(vals)['min']
- vals_max = FigureFactory._calc_stats(vals)['max']
- q1 = FigureFactory._calc_stats(vals)['q1']
- q2 = FigureFactory._calc_stats(vals)['q2']
- q3 = FigureFactory._calc_stats(vals)['q3']
- d1 = FigureFactory._calc_stats(vals)['d1']
- d2 = FigureFactory._calc_stats(vals)['d2']
-
- # kernel density estimation of pdf
- pdf = stats.gaussian_kde(vals)
- # grid over the data interval
- xx = np.linspace(vals_min, vals_max, 100)
- # evaluate the pdf at the grid xx
- yy = pdf(xx)
- max_pdf = np.max(yy)
- # distance from the violin plot to rugplot
- distance = (2.0 * max_pdf)/10 if rugplot else 0
- # range for x values in the plot
- plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
- plot_data = [FigureFactory._make_half_violin(
- -yy, xx, fillcolor=fillcolor),
- FigureFactory._make_half_violin(
- yy, xx, fillcolor=fillcolor),
- FigureFactory._make_non_outlier_interval(d1, d2),
- FigureFactory._make_quartiles(q1, q3),
- FigureFactory._make_median(q2)]
- if rugplot:
- plot_data.append(FigureFactory._make_violin_rugplot(
- vals,
- max_pdf,
- distance=distance,
- color=fillcolor)
- )
- return plot_data, plot_xrange
-
- @staticmethod
- def _violin_no_colorscale(data, data_header, group_header, colors,
- use_colorscale, group_stats,
- height, width, title):
- """
- Refer to FigureFactory.create_violin() for docstring.
-
- Returns fig for violin plot without colorscale.
-
- """
- from plotly.graph_objs import graph_objs
- import numpy as np
-
- # collect all group names
- group_name = []
- for name in data[group_header]:
- if name not in group_name:
- group_name.append(name)
- group_name.sort()
-
- gb = data.groupby([group_header])
- L = len(group_name)
-
- fig = make_subplots(rows=1, cols=L,
- shared_yaxes=True,
- horizontal_spacing=0.025,
- print_grid=False)
- color_index = 0
- for k, gr in enumerate(group_name):
- vals = np.asarray(gb.get_group(gr)[data_header], np.float)
- if color_index >= len(colors):
- color_index = 0
- plot_data, plot_xrange = FigureFactory._violinplot(
- vals,
- fillcolor=colors[color_index]
- )
- layout = graph_objs.Layout()
-
- for item in plot_data:
- fig.append_trace(item, 1, k + 1)
- color_index += 1
-
- # add violin plot labels
- fig['layout'].update({'xaxis{}'.format(k + 1):
- FigureFactory._make_XAxis(group_name[k],
- plot_xrange)})
-
- # set the sharey axis style
- fig['layout'].update(
- {'yaxis{}'.format(1): FigureFactory._make_YAxis('')}
- )
- fig['layout'].update(
- title=title,
- showlegend=False,
- hovermode='closest',
- autosize=False,
- height=height,
- width=width
- )
-
- return fig
+ def create_2D_density(*args, **kwargs):
+ FigureFactory._deprecated('create_2D_density', 'create_2d_density')
+ from plotly.figure_factory import create_2d_density
+ return create_2d_density(*args, **kwargs)
@staticmethod
- def _violin_colorscale(data, data_header, group_header, colors,
- use_colorscale, group_stats, height, width,
- title):
- """
- Refer to FigureFactory.create_violin() for docstring.
-
- Returns fig for violin plot with colorscale.
-
- """
- from plotly.graph_objs import graph_objs
- import numpy as np
-
- # collect all group names
- group_name = []
- for name in data[group_header]:
- if name not in group_name:
- group_name.append(name)
- group_name.sort()
-
- # make sure all group names are keys in group_stats
- for group in group_name:
- if group not in group_stats:
- raise exceptions.PlotlyError("All values/groups in the index "
- "column must be represented "
- "as a key in group_stats.")
-
- gb = data.groupby([group_header])
- L = len(group_name)
-
- fig = make_subplots(rows=1, cols=L,
- shared_yaxes=True,
- horizontal_spacing=0.025,
- print_grid=False)
-
- # prepare low and high color for colorscale
- lowcolor = FigureFactory._color_parser(
- colors[0], FigureFactory._unlabel_rgb
- )
- highcolor = FigureFactory._color_parser(
- colors[1], FigureFactory._unlabel_rgb
- )
-
- # find min and max values in group_stats
- group_stats_values = []
- for key in group_stats:
- group_stats_values.append(group_stats[key])
-
- max_value = max(group_stats_values)
- min_value = min(group_stats_values)
-
- for k, gr in enumerate(group_name):
- vals = np.asarray(gb.get_group(gr)[data_header], np.float)
-
- # find intermediate color from colorscale
- intermed = (group_stats[gr] - min_value) / (max_value - min_value)
- intermed_color = FigureFactory._find_intermediate_color(
- lowcolor, highcolor, intermed
- )
-
- plot_data, plot_xrange = FigureFactory._violinplot(
- vals,
- fillcolor='rgb{}'.format(intermed_color)
- )
- layout = graph_objs.Layout()
-
- for item in plot_data:
- fig.append_trace(item, 1, k + 1)
- fig['layout'].update({'xaxis{}'.format(k + 1):
- FigureFactory._make_XAxis(group_name[k],
- plot_xrange)})
- # add colorbar to plot
- trace_dummy = graph_objs.Scatter(
- x=[0],
- y=[0],
- mode='markers',
- marker=dict(
- size=2,
- cmin=min_value,
- cmax=max_value,
- colorscale=[[0, colors[0]],
- [1, colors[1]]],
- showscale=True),
- showlegend=False,
- )
- fig.append_trace(trace_dummy, 1, L)
-
- # set the sharey axis style
- fig['layout'].update(
- {'yaxis{}'.format(1): FigureFactory._make_YAxis('')}
- )
- fig['layout'].update(
- title=title,
- showlegend=False,
- hovermode='closest',
- autosize=False,
- height=height,
- width=width
- )
-
- return fig
+ def create_annotated_heatmap(*args, **kwargs):
+ FigureFactory._deprecated('create_annotated_heatmap')
+ from plotly.figure_factory import create_annotated_heatmap
+ return create_annotated_heatmap(*args, **kwargs)
@staticmethod
- def _violin_dict(data, data_header, group_header, colors, use_colorscale,
- group_stats, height, width, title):
- """
- Refer to FigureFactory.create_violin() for docstring.
-
- Returns fig for violin plot without colorscale.
-
- """
- from plotly.graph_objs import graph_objs
- import numpy as np
-
- # collect all group names
- group_name = []
- for name in data[group_header]:
- if name not in group_name:
- group_name.append(name)
- group_name.sort()
-
- # check if all group names appear in colors dict
- for group in group_name:
- if group not in colors:
- raise exceptions.PlotlyError("If colors is a dictionary, all "
- "the group names must appear as "
- "keys in colors.")
-
- gb = data.groupby([group_header])
- L = len(group_name)
-
- fig = make_subplots(rows=1, cols=L,
- shared_yaxes=True,
- horizontal_spacing=0.025,
- print_grid=False)
-
- for k, gr in enumerate(group_name):
- vals = np.asarray(gb.get_group(gr)[data_header], np.float)
- plot_data, plot_xrange = FigureFactory._violinplot(
- vals,
- fillcolor=colors[gr]
- )
- layout = graph_objs.Layout()
-
- for item in plot_data:
- fig.append_trace(item, 1, k + 1)
-
- # add violin plot labels
- fig['layout'].update({'xaxis{}'.format(k + 1):
- FigureFactory._make_XAxis(group_name[k],
- plot_xrange)})
-
- # set the sharey axis style
- fig['layout'].update(
- {'yaxis{}'.format(1): FigureFactory._make_YAxis('')}
- )
- fig['layout'].update(
- title=title,
- showlegend=False,
- hovermode='closest',
- autosize=False,
- height=height,
- width=width
- )
-
- return fig
+ def create_candlestick(*args, **kwargs):
+ FigureFactory._deprecated('create_candlestick')
+ from plotly.figure_factory import create_candlestick
+ return create_candlestick(*args, **kwargs)
@staticmethod
- def create_violin(data, data_header=None, group_header=None,
- colors=None, use_colorscale=False, group_stats=None,
- height=450, width=600, title='Violin and Rug Plot'):
- """
- Returns figure for a violin plot
-
- :param (list|array) data: accepts either a list of numerical values,
- a list of dictionaries all with identical keys and at least one
- column of numeric values, or a pandas dataframe with at least one
- column of numbers
- :param (str) data_header: the header of the data column to be used
- from an inputted pandas dataframe. Not applicable if 'data' is
- a list of numeric values
- :param (str) group_header: applicable if grouping data by a variable.
- 'group_header' must be set to the name of the grouping variable.
- :param (str|tuple|list|dict) colors: either a plotly scale name,
- an rgb or hex color, a color tuple, a list of colors or a
- dictionary. An rgb color is of the form 'rgb(x, y, z)' where
- x, y and z belong to the interval [0, 255] and a color tuple is a
- tuple of the form (a, b, c) where a, b and c belong to [0, 1].
- If colors is a list, it must contain valid color types as its
- members.
- :param (bool) use_colorscale: Only applicable if grouping by another
- variable. Will implement a colorscale based on the first 2 colors
- of param colors. This means colors must be a list with at least 2
- colors in it (Plotly colorscales are accepted since they map to a
- list of two rgb colors)
- :param (dict) group_stats: a dictioanry where each key is a unique
- value from the group_header column in data. Each value must be a
- number and will be used to color the violin plots if a colorscale
- is being used
- :param (float) height: the height of the violin plot
- :param (float) width: the width of the violin plot
- :param (str) title: the title of the violin plot
-
- Example 1: Single Violin Plot
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import graph_objs
-
- import numpy as np
- from scipy import stats
-
- # create list of random values
- data_list = np.random.randn(100)
- data_list.tolist()
-
- # create violin fig
- fig = FF.create_violin(data_list, colors='#604d9e')
-
- # plot
- py.iplot(fig, filename='Violin Plot')
- ```
-
- Example 2: Multiple Violin Plots with Qualitative Coloring
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import graph_objs
-
- import numpy as np
- import pandas as pd
- from scipy import stats
-
- # create dataframe
- np.random.seed(619517)
- Nr=250
- y = np.random.randn(Nr)
- gr = np.random.choice(list("ABCDE"), Nr)
- norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
-
- for i, letter in enumerate("ABCDE"):
- y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
- df = pd.DataFrame(dict(Score=y, Group=gr))
-
- # create violin fig
- fig = FF.create_violin(df, data_header='Score', group_header='Group',
- height=600, width=1000)
-
- # plot
- py.iplot(fig, filename='Violin Plot with Coloring')
- ```
-
- Example 3: Violin Plots with Colorscale
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import graph_objs
-
- import numpy as np
- import pandas as pd
- from scipy import stats
-
- # create dataframe
- np.random.seed(619517)
- Nr=250
- y = np.random.randn(Nr)
- gr = np.random.choice(list("ABCDE"), Nr)
- norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
-
- for i, letter in enumerate("ABCDE"):
- y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
- df = pd.DataFrame(dict(Score=y, Group=gr))
-
- # define header params
- data_header = 'Score'
- group_header = 'Group'
-
- # make groupby object with pandas
- group_stats = {}
- groupby_data = df.groupby([group_header])
-
- for group in "ABCDE":
- data_from_group = groupby_data.get_group(group)[data_header]
- # take a stat of the grouped data
- stat = np.median(data_from_group)
- # add to dictionary
- group_stats[group] = stat
-
- # create violin fig
- fig = FF.create_violin(df, data_header='Score', group_header='Group',
- height=600, width=1000, use_colorscale=True,
- group_stats=group_stats)
-
- # plot
- py.iplot(fig, filename='Violin Plot with Colorscale')
- ```
- """
- from plotly.graph_objs import graph_objs
- from numbers import Number
-
- # Validate colors
- if isinstance(colors, dict):
- valid_colors = FigureFactory._validate_colors_dict(colors, 'rgb')
- else:
- valid_colors = FigureFactory._validate_colors(colors, 'rgb')
-
- # validate data and choose plot type
- if group_header is None:
- if isinstance(data, list):
- if len(data) <= 0:
- raise exceptions.PlotlyError("If data is a list, it must be "
- "nonempty and contain either "
- "numbers or dictionaries.")
-
- if not all(isinstance(element, Number) for element in data):
- raise exceptions.PlotlyError("If data is a list, it must "
- "contain only numbers.")
-
- if _pandas_imported and isinstance(data, pd.core.frame.DataFrame):
- if data_header is None:
- raise exceptions.PlotlyError("data_header must be the "
- "column name with the "
- "desired numeric data for "
- "the violin plot.")
-
- data = data[data_header].values.tolist()
-
- # call the plotting functions
- plot_data, plot_xrange = FigureFactory._violinplot(
- data, fillcolor=valid_colors[0]
- )
-
- layout = graph_objs.Layout(
- title=title,
- autosize=False,
- font=graph_objs.Font(size=11),
- height=height,
- showlegend=False,
- width=width,
- xaxis=FigureFactory._make_XAxis('', plot_xrange),
- yaxis=FigureFactory._make_YAxis(''),
- hovermode='closest'
- )
- layout['yaxis'].update(dict(showline=False,
- showticklabels=False,
- ticks=''))
-
- fig = graph_objs.Figure(data=graph_objs.Data(plot_data),
- layout=layout)
-
- return fig
-
- else:
- if not isinstance(data, pd.core.frame.DataFrame):
- raise exceptions.PlotlyError("Error. You must use a pandas "
- "DataFrame if you are using a "
- "group header.")
-
- if data_header is None:
- raise exceptions.PlotlyError("data_header must be the column "
- "name with the desired numeric "
- "data for the violin plot.")
-
- if use_colorscale is False:
- if isinstance(valid_colors, dict):
- # validate colors dict choice below
- fig = FigureFactory._violin_dict(
- data, data_header, group_header, valid_colors,
- use_colorscale, group_stats, height, width, title
- )
- return fig
- else:
- fig = FigureFactory._violin_no_colorscale(
- data, data_header, group_header, valid_colors,
- use_colorscale, group_stats, height, width, title
- )
- return fig
- else:
- if isinstance(valid_colors, dict):
- raise exceptions.PlotlyError("The colors param cannot be "
- "a dictionary if you are "
- "using a colorscale.")
-
- if len(valid_colors) < 2:
- raise exceptions.PlotlyError("colors must be a list with "
- "at least 2 colors. A "
- "Plotly scale is allowed.")
-
- if not isinstance(group_stats, dict):
- raise exceptions.PlotlyError("Your group_stats param "
- "must be a dictionary.")
-
- fig = FigureFactory._violin_colorscale(
- data, data_header, group_header, valid_colors,
- use_colorscale, group_stats, height, width, title
- )
- return fig
+ def create_dendrogram(*args, **kwargs):
+ FigureFactory._deprecated('create_dendrogram')
+ from plotly.figure_factory import create_dendrogram
+ return create_dendrogram(*args, **kwargs)
@staticmethod
- def _find_intermediate_color(lowcolor, highcolor, intermed):
- """
- Returns the color at a given distance between two colors
-
- This function takes two color tuples, where each element is between 0
- and 1, along with a value 0 < intermed < 1 and returns a color that is
- intermed-percent from lowcolor to highcolor
-
- """
- diff_0 = float(highcolor[0] - lowcolor[0])
- diff_1 = float(highcolor[1] - lowcolor[1])
- diff_2 = float(highcolor[2] - lowcolor[2])
-
- return (lowcolor[0] + intermed * diff_0,
- lowcolor[1] + intermed * diff_1,
- lowcolor[2] + intermed * diff_2)
+ def create_distplot(*args, **kwargs):
+ FigureFactory._deprecated('create_distplot')
+ from plotly.figure_factory import create_distplot
+ return create_distplot(*args, **kwargs)
@staticmethod
- def _color_parser(colors, function):
- """
- Takes color(s) and a function and applies the function on the color(s)
-
- In particular, this function identifies whether the given color object
- is an iterable or not and applies the given color-parsing function to
- the color or iterable of colors. If given an iterable, it will only be
- able to work with it if all items in the iterable are of the same type
- - rgb string, hex string or tuple
-
- """
- from numbers import Number
- if isinstance(colors, str):
- return function(colors)
-
- if isinstance(colors, tuple) and isinstance(colors[0], Number):
- return function(colors)
-
- if hasattr(colors, '__iter__'):
- if isinstance(colors, tuple):
- new_color_tuple = tuple(function(item) for item in colors)
- return new_color_tuple
-
- else:
- new_color_list = [function(item) for item in colors]
- return new_color_list
+ def create_gantt(*args, **kwargs):
+ FigureFactory._deprecated('create_gantt')
+ from plotly.figure_factory import create_gantt
+ return create_gantt(*args, **kwargs)
@staticmethod
- def _unconvert_from_RGB_255(colors):
- """
- Return a tuple where each element gets divided by 255
-
- Takes a (list of) color tuple(s) where each element is between 0 and
- 255. Returns the same tuples where each tuple element is normalized to
- a value between 0 and 1
-
- """
- return (colors[0]/(255.0),
- colors[1]/(255.0),
- colors[2]/(255.0))
+ def create_ohlc(*args, **kwargs):
+ FigureFactory._deprecated('create_ohlc')
+ from plotly.figure_factory import create_ohlc
+ return create_ohlc(*args, **kwargs)
@staticmethod
- def _map_face2color(face, colormap, scale, vmin, vmax):
- """
- Normalize facecolor values by vmin/vmax and return rgb-color strings
-
- This function takes a tuple color along with a colormap and a minimum
- (vmin) and maximum (vmax) range of possible mean distances for the
- given parametrized surface. It returns an rgb color based on the mean
- distance between vmin and vmax
-
- """
- if vmin >= vmax:
- raise exceptions.PlotlyError("Incorrect relation between vmin "
- "and vmax. The vmin value cannot be "
- "bigger than or equal to the value "
- "of vmax.")
- if len(colormap) == 1:
- # color each triangle face with the same color in colormap
- face_color = colormap[0]
- face_color = colors.convert_to_RGB_255(face_color)
- face_color = colors.label_rgb(face_color)
- return face_color
- if face == vmax:
- # pick last color in colormap
- face_color = colormap[-1]
- face_color = colors.convert_to_RGB_255(face_color)
- face_color = colors.label_rgb(face_color)
- return face_color
- else:
- if scale is None:
- # find the normalized distance t of a triangle face between
- # vmin and vmax where the distance is between 0 and 1
- t = (face - vmin) / float((vmax - vmin))
- low_color_index = int(t / (1./(len(colormap) - 1)))
-
- face_color = colors.find_intermediate_color(
- colormap[low_color_index],
- colormap[low_color_index + 1],
- t * (len(colormap) - 1) - low_color_index
- )
-
- face_color = colors.convert_to_RGB_255(face_color)
- face_color = colors.label_rgb(face_color)
- else:
- # find the face color for a non-linearly interpolated scale
- t = (face - vmin) / float((vmax - vmin))
-
- low_color_index = 0
- for k in range(len(scale) - 1):
- if scale[k] <= t < scale[k+1]:
- break
- low_color_index += 1
-
- low_scale_val = scale[low_color_index]
- high_scale_val = scale[low_color_index + 1]
-
- face_color = colors.find_intermediate_color(
- colormap[low_color_index],
- colormap[low_color_index + 1],
- (t - low_scale_val)/(high_scale_val - low_scale_val)
- )
-
- face_color = colors.convert_to_RGB_255(face_color)
- face_color = colors.label_rgb(face_color)
- return face_color
+ def create_quiver(*args, **kwargs):
+ FigureFactory._deprecated('create_quiver')
+ from plotly.figure_factory import create_quiver
+ return create_quiver(*args, **kwargs)
@staticmethod
- def _trisurf(x, y, z, simplices, show_colorbar, edges_color, scale,
- colormap=None, color_func=None, plot_edges=False,
- x_edge=None, y_edge=None, z_edge=None, facecolor=None):
- """
- Refer to FigureFactory.create_trisurf() for docstring
- """
- # numpy import check
- if _numpy_imported is False:
- raise ImportError("FigureFactory._trisurf() requires "
- "numpy imported.")
- import numpy as np
- from plotly.graph_objs import graph_objs
- points3D = np.vstack((x, y, z)).T
- simplices = np.atleast_2d(simplices)
-
- # vertices of the surface triangles
- tri_vertices = points3D[simplices]
-
- # Define colors for the triangle faces
- if color_func is None:
- # mean values of z-coordinates of triangle vertices
- mean_dists = tri_vertices[:, :, 2].mean(-1)
- elif isinstance(color_func, (list, np.ndarray)):
- # Pre-computed list / array of values to map onto color
- if len(color_func) != len(simplices):
- raise ValueError("If color_func is a list/array, it must "
- "be the same length as simplices.")
-
- # convert all colors in color_func to rgb
- for index in range(len(color_func)):
- if isinstance(color_func[index], str):
- if '#' in color_func[index]:
- foo = colors.hex_to_rgb(color_func[index])
- color_func[index] = colors.label_rgb(foo)
-
- if isinstance(color_func[index], tuple):
- foo = colors.convert_to_RGB_255(color_func[index])
- color_func[index] = colors.label_rgb(foo)
-
- mean_dists = np.asarray(color_func)
- else:
- # apply user inputted function to calculate
- # custom coloring for triangle vertices
- mean_dists = []
- for triangle in tri_vertices:
- dists = []
- for vertex in triangle:
- dist = color_func(vertex[0], vertex[1], vertex[2])
- dists.append(dist)
- mean_dists.append(np.mean(dists))
- mean_dists = np.asarray(mean_dists)
-
- # Check if facecolors are already strings and can be skipped
- if isinstance(mean_dists[0], str):
- facecolor = mean_dists
- else:
- min_mean_dists = np.min(mean_dists)
- max_mean_dists = np.max(mean_dists)
-
- if facecolor is None:
- facecolor = []
- for index in range(len(mean_dists)):
- color = FigureFactory._map_face2color(mean_dists[index],
- colormap,
- scale,
- min_mean_dists,
- max_mean_dists)
- facecolor.append(color)
-
- # Make sure facecolor is a list so output is consistent across Pythons
- facecolor = np.asarray(facecolor)
- ii, jj, kk = simplices.T
-
- triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor,
- i=ii, j=jj, k=kk, name='')
-
- mean_dists_are_numbers = not isinstance(mean_dists[0], str)
-
- if mean_dists_are_numbers and show_colorbar is True:
- # make a colorscale from the colors
- colorscale = colors.make_colorscale(colormap, scale)
- colorscale = colors.convert_colorscale_to_rgb(colorscale)
-
- colorbar = graph_objs.Scatter3d(
- x=x[:1],
- y=y[:1],
- z=z[:1],
- mode='markers',
- marker=dict(
- size=0.1,
- color=[min_mean_dists, max_mean_dists],
- colorscale=colorscale,
- showscale=True),
- hoverinfo='None',
- showlegend=False
- )
-
- # the triangle sides are not plotted
- if plot_edges is False:
- if mean_dists_are_numbers and show_colorbar is True:
- return graph_objs.Data([triangles, colorbar])
- else:
- return graph_objs.Data([triangles])
-
- # define the lists x_edge, y_edge and z_edge, of x, y, resp z
- # coordinates of edge end points for each triangle
- # None separates data corresponding to two consecutive triangles
- is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
- if any(is_none):
- if not all(is_none):
- raise ValueError("If any (x_edge, y_edge, z_edge) is None, "
- "all must be None")
- else:
- x_edge = []
- y_edge = []
- z_edge = []
-
- # Pull indices we care about, then add a None column to separate tris
- ixs_triangles = [0, 1, 2, 0]
- pull_edges = tri_vertices[:, ixs_triangles, :]
- x_edge_pull = np.hstack([pull_edges[:, :, 0],
- np.tile(None, [pull_edges.shape[0], 1])])
- y_edge_pull = np.hstack([pull_edges[:, :, 1],
- np.tile(None, [pull_edges.shape[0], 1])])
- z_edge_pull = np.hstack([pull_edges[:, :, 2],
- np.tile(None, [pull_edges.shape[0], 1])])
-
- # Now unravel the edges into a 1-d vector for plotting
- x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
- y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
- z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])
-
- if not (len(x_edge) == len(y_edge) == len(z_edge)):
- raise exceptions.PlotlyError("The lengths of x_edge, y_edge and "
- "z_edge are not the same.")
-
- # define the lines for plotting
- lines = graph_objs.Scatter3d(
- x=x_edge, y=y_edge, z=z_edge, mode='lines',
- line=graph_objs.Line(
- color=edges_color,
- width=1.5
- ),
- showlegend=False
- )
-
- if mean_dists_are_numbers and show_colorbar is True:
- return graph_objs.Data([triangles, lines, colorbar])
- else:
- return graph_objs.Data([triangles, lines])
+ def create_scatterplotmatrix(*args, **kwargs):
+ FigureFactory._deprecated('create_scatterplotmatrix')
+ from plotly.figure_factory import create_scatterplotmatrix
+ return create_scatterplotmatrix(*args, **kwargs)
@staticmethod
- def create_trisurf(x, y, z, simplices, colormap=None, show_colorbar=True,
- scale=None, color_func=None, title='Trisurf Plot',
- plot_edges=True, showbackground=True,
- backgroundcolor='rgb(230, 230, 230)',
- gridcolor='rgb(255, 255, 255)',
- zerolinecolor='rgb(255, 255, 255)',
- edges_color='rgb(50, 50, 50)',
- height=800, width=800,
- aspectratio=dict(x=1, y=1, z=1)):
- """
- Returns figure for a triangulated surface plot
-
- :param (array) x: data values of x in a 1D array
- :param (array) y: data values of y in a 1D array
- :param (array) z: data values of z in a 1D array
- :param (array) simplices: an array of shape (ntri, 3) where ntri is
- the number of triangles in the triangularization. Each row of the
- array contains the indicies of the verticies of each triangle
- :param (str|tuple|list) colormap: either a plotly scale name, an rgb
- or hex color, a color tuple or a list of colors. An rgb color is
- of the form 'rgb(x, y, z)' where x, y, z belong to the interval
- [0, 255] and a color tuple is a tuple of the form (a, b, c) where
- a, b and c belong to [0, 1]. If colormap is a list, it must
- contain the valid color types aforementioned as its members
- :param (bool) show_colorbar: determines if colorbar is visible
- :param (list|array) scale: sets the scale values to be used if a non-
- linearly interpolated colormap is desired. If left as None, a
- linear interpolation between the colors will be excecuted
- :param (function|list) color_func: The parameter that determines the
- coloring of the surface. Takes either a function with 3 arguments
- x, y, z or a list/array of color values the same length as
- simplices. If None, coloring will only depend on the z axis
- :param (str) title: title of the plot
- :param (bool) plot_edges: determines if the triangles on the trisurf
- are visible
- :param (bool) showbackground: makes background in plot visible
- :param (str) backgroundcolor: color of background. Takes a string of
- the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
- :param (str) gridcolor: color of the gridlines besides the axes. Takes
- a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255
- inclusive
- :param (str) zerolinecolor: color of the axes. Takes a string of the
- form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
- :param (str) edges_color: color of the edges, if plot_edges is True
- :param (int|float) height: the height of the plot (in pixels)
- :param (int|float) width: the width of the plot (in pixels)
- :param (dict) aspectratio: a dictionary of the aspect ratio values for
- the x, y and z axes. 'x', 'y' and 'z' take (int|float) values
-
- Example 1: Sphere
- ```
- # Necessary Imports for Trisurf
- import numpy as np
- from scipy.spatial import Delaunay
-
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import graph_objs
-
- # Make data for plot
- u = np.linspace(0, 2*np.pi, 20)
- v = np.linspace(0, np.pi, 20)
- u,v = np.meshgrid(u,v)
- u = u.flatten()
- v = v.flatten()
-
- x = np.sin(v)*np.cos(u)
- y = np.sin(v)*np.sin(u)
- z = np.cos(v)
-
- points2D = np.vstack([u,v]).T
- tri = Delaunay(points2D)
- simplices = tri.simplices
-
- # Create a figure
- fig1 = FF.create_trisurf(x=x, y=y, z=z,
- colormap="Rainbow",
- simplices=simplices)
- # Plot the data
- py.iplot(fig1, filename='trisurf-plot-sphere')
- ```
-
- Example 2: Torus
- ```
- # Necessary Imports for Trisurf
- import numpy as np
- from scipy.spatial import Delaunay
-
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import graph_objs
-
- # Make data for plot
- u = np.linspace(0, 2*np.pi, 20)
- v = np.linspace(0, 2*np.pi, 20)
- u,v = np.meshgrid(u,v)
- u = u.flatten()
- v = v.flatten()
-
- x = (3 + (np.cos(v)))*np.cos(u)
- y = (3 + (np.cos(v)))*np.sin(u)
- z = np.sin(v)
-
- points2D = np.vstack([u,v]).T
- tri = Delaunay(points2D)
- simplices = tri.simplices
-
- # Create a figure
- fig1 = FF.create_trisurf(x=x, y=y, z=z,
- colormap="Viridis",
- simplices=simplices)
- # Plot the data
- py.iplot(fig1, filename='trisurf-plot-torus')
- ```
-
- Example 3: Mobius Band
- ```
- # Necessary Imports for Trisurf
- import numpy as np
- from scipy.spatial import Delaunay
-
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import graph_objs
-
- # Make data for plot
- u = np.linspace(0, 2*np.pi, 24)
- v = np.linspace(-1, 1, 8)
- u,v = np.meshgrid(u,v)
- u = u.flatten()
- v = v.flatten()
-
- tp = 1 + 0.5*v*np.cos(u/2.)
- x = tp*np.cos(u)
- y = tp*np.sin(u)
- z = 0.5*v*np.sin(u/2.)
-
- points2D = np.vstack([u,v]).T
- tri = Delaunay(points2D)
- simplices = tri.simplices
-
- # Create a figure
- fig1 = FF.create_trisurf(x=x, y=y, z=z,
- colormap=[(0.2, 0.4, 0.6), (1, 1, 1)],
- simplices=simplices)
- # Plot the data
- py.iplot(fig1, filename='trisurf-plot-mobius-band')
- ```
-
- Example 4: Using a Custom Colormap Function with Light Cone
- ```
- # Necessary Imports for Trisurf
- import numpy as np
- from scipy.spatial import Delaunay
-
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import graph_objs
-
- # Make data for plot
- u=np.linspace(-np.pi, np.pi, 30)
- v=np.linspace(-np.pi, np.pi, 30)
- u,v=np.meshgrid(u,v)
- u=u.flatten()
- v=v.flatten()
-
- x = u
- y = u*np.cos(v)
- z = u*np.sin(v)
-
- points2D = np.vstack([u,v]).T
- tri = Delaunay(points2D)
- simplices = tri.simplices
-
- # Define distance function
- def dist_origin(x, y, z):
- return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2)
-
- # Create a figure
- fig1 = FF.create_trisurf(x=x, y=y, z=z,
- colormap=['#FFFFFF', '#E4FFFE',
- '#A4F6F9', '#FF99FE',
- '#BA52ED'],
- scale=[0, 0.6, 0.71, 0.89, 1],
- simplices=simplices,
- color_func=dist_origin)
- # Plot the data
- py.iplot(fig1, filename='trisurf-plot-custom-coloring')
- ```
-
- Example 5: Enter color_func as a list of colors
- ```
- # Necessary Imports for Trisurf
- import numpy as np
- from scipy.spatial import Delaunay
- import random
-
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import graph_objs
-
- # Make data for plot
- u=np.linspace(-np.pi, np.pi, 30)
- v=np.linspace(-np.pi, np.pi, 30)
- u,v=np.meshgrid(u,v)
- u=u.flatten()
- v=v.flatten()
-
- x = u
- y = u*np.cos(v)
- z = u*np.sin(v)
-
- points2D = np.vstack([u,v]).T
- tri = Delaunay(points2D)
- simplices = tri.simplices
-
-
- colors = []
- color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd']
-
- for index in range(len(simplices)):
- colors.append(random.choice(color_choices))
-
- fig = FF.create_trisurf(
- x, y, z, simplices,
- color_func=colors,
- show_colorbar=True,
- edges_color='rgb(2, 85, 180)',
- title=' Modern Art'
- )
-
- py.iplot(fig, filename="trisurf-plot-modern-art")
- ```
- """
- from plotly.graph_objs import graph_objs
-
- # Validate colormap
- colors.validate_colors(colormap)
- colormap, scale = colors.convert_colors_to_same_type(
- colormap, colortype='tuple',
- return_default_colors=True, scale=scale
- )
-
- data1 = FigureFactory._trisurf(x, y, z, simplices,
- show_colorbar=show_colorbar,
- color_func=color_func,
- colormap=colormap,
- scale=scale,
- edges_color=edges_color,
- plot_edges=plot_edges)
-
- axis = dict(
- showbackground=showbackground,
- backgroundcolor=backgroundcolor,
- gridcolor=gridcolor,
- zerolinecolor=zerolinecolor,
- )
- layout = graph_objs.Layout(
- title=title,
- width=width,
- height=height,
- scene=graph_objs.Scene(
- xaxis=graph_objs.XAxis(axis),
- yaxis=graph_objs.YAxis(axis),
- zaxis=graph_objs.ZAxis(axis),
- aspectratio=dict(
- x=aspectratio['x'],
- y=aspectratio['y'],
- z=aspectratio['z']),
- )
- )
-
- return graph_objs.Figure(data=data1, layout=layout)
+ def create_streamline(*args, **kwargs):
+ FigureFactory._deprecated('create_streamline')
+ from plotly.figure_factory import create_streamline
+ return create_streamline(*args, **kwargs)
@staticmethod
- def _scatterplot(dataframe, headers, diag, size,
- height, width, title, **kwargs):
- """
- Refer to FigureFactory.create_scatterplotmatrix() for docstring
-
- Returns fig for scatterplotmatrix without index
-
- """
- from plotly.graph_objs import graph_objs
- dim = len(dataframe)
- fig = make_subplots(rows=dim, cols=dim, print_grid=False)
- trace_list = []
- # Insert traces into trace_list
- for listy in dataframe:
- for listx in dataframe:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=listx,
- showlegend=False
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=listx,
- name=None,
- showlegend=False
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- trace = graph_objs.Scatter(
- x=listx,
- y=listy,
- mode='markers',
- showlegend=False,
- **kwargs
- )
- trace_list.append(trace)
- else:
- trace = graph_objs.Scatter(
- x=listx,
- y=listy,
- mode='markers',
- marker=dict(
- size=size),
- showlegend=False,
- **kwargs
- )
- trace_list.append(trace)
-
- trace_index = 0
- indices = range(1, dim + 1)
- for y_index in indices:
- for x_index in indices:
- fig.append_trace(trace_list[trace_index],
- y_index,
- x_index)
- trace_index += 1
-
- # Insert headers into the figure
- for j in range(dim):
- xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
- fig['layout'][xaxis_key].update(title=headers[j])
- for j in range(dim):
- yaxis_key = 'yaxis{}'.format(1 + (dim * j))
- fig['layout'][yaxis_key].update(title=headers[j])
-
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True
- )
-
- FigureFactory._hide_tick_labels_from_box_subplots(fig)
-
- return fig
+ def create_table(*args, **kwargs):
+ FigureFactory._deprecated('create_table')
+ from plotly.figure_factory import create_table
+ return create_table(*args, **kwargs)
@staticmethod
- def _scatterplot_dict(dataframe, headers, diag, size,
- height, width, title, index, index_vals,
- endpts, colormap, colormap_type, **kwargs):
- """
- Refer to FigureFactory.create_scatterplotmatrix() for docstring
-
- Returns fig for scatterplotmatrix with both index and colormap picked.
- Used if colormap is a dictionary with index values as keys pointing to
- colors. Forces colormap_type to behave categorically because it would
- not make sense colors are assigned to each index value and thus
- implies that a categorical approach should be taken
-
- """
- from plotly.graph_objs import graph_objs
-
- theme = colormap
- dim = len(dataframe)
- fig = make_subplots(rows=dim, cols=dim, print_grid=False)
- trace_list = []
- legend_param = 0
- # Work over all permutations of list pairs
- for listy in dataframe:
- for listx in dataframe:
- # create a dictionary for index_vals
- unique_index_vals = {}
- for name in index_vals:
- if name not in unique_index_vals:
- unique_index_vals[name] = []
-
- # Fill all the rest of the names into the dictionary
- for name in sorted(unique_index_vals.keys()):
- new_listx = []
- new_listy = []
- for j in range(len(index_vals)):
- if index_vals[j] == name:
- new_listx.append(listx[j])
- new_listy.append(listy[j])
- # Generate trace with VISIBLE icon
- if legend_param == 1:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=new_listx,
- marker=dict(
- color=theme[name]),
- showlegend=True
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=new_listx,
- name=None,
- marker=dict(
- color=theme[name]),
- showlegend=True
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- kwargs['marker']['color'] = theme[name]
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=name,
- showlegend=True,
- **kwargs
- )
- else:
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=name,
- marker=dict(
- size=size,
- color=theme[name]),
- showlegend=True,
- **kwargs
- )
- # Generate trace with INVISIBLE icon
- else:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=new_listx,
- marker=dict(
- color=theme[name]),
- showlegend=False
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=new_listx,
- name=None,
- marker=dict(
- color=theme[name]),
- showlegend=False
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- kwargs['marker']['color'] = theme[name]
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=name,
- showlegend=False,
- **kwargs
- )
- else:
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=name,
- marker=dict(
- size=size,
- color=theme[name]),
- showlegend=False,
- **kwargs
- )
- # Push the trace into dictionary
- unique_index_vals[name] = trace
- trace_list.append(unique_index_vals)
- legend_param += 1
-
- trace_index = 0
- indices = range(1, dim + 1)
- for y_index in indices:
- for x_index in indices:
- for name in sorted(trace_list[trace_index].keys()):
- fig.append_trace(
- trace_list[trace_index][name],
- y_index,
- x_index)
- trace_index += 1
-
- # Insert headers into the figure
- for j in range(dim):
- xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
- fig['layout'][xaxis_key].update(title=headers[j])
-
- for j in range(dim):
- yaxis_key = 'yaxis{}'.format(1 + (dim * j))
- fig['layout'][yaxis_key].update(title=headers[j])
-
- FigureFactory._hide_tick_labels_from_box_subplots(fig)
-
- if diag == 'histogram':
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True,
- barmode='stack')
- return fig
-
- else:
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True)
- return fig
+ def create_trisurf(*args, **kwargs):
+ FigureFactory._deprecated('create_trisurf')
+ from plotly.figure_factory import create_trisurf
+ return create_trisurf(*args, **kwargs)
@staticmethod
- def _scatterplot_theme(dataframe, headers, diag, size, height,
- width, title, index, index_vals, endpts,
- colormap, colormap_type, **kwargs):
- """
- Refer to FigureFactory.create_scatterplotmatrix() for docstring
-
- Returns fig for scatterplotmatrix with both index and colormap picked
-
- """
- from plotly.graph_objs import graph_objs
-
- # Check if index is made of string values
- if isinstance(index_vals[0], str):
- unique_index_vals = []
- for name in index_vals:
- if name not in unique_index_vals:
- unique_index_vals.append(name)
- n_colors_len = len(unique_index_vals)
-
- # Convert colormap to list of n RGB tuples
- if colormap_type == 'seq':
- foo = FigureFactory._color_parser(
- colormap, FigureFactory._unlabel_rgb
- )
- foo = FigureFactory._n_colors(foo[0],
- foo[1],
- n_colors_len)
- theme = FigureFactory._color_parser(
- foo, FigureFactory._label_rgb
- )
-
- if colormap_type == 'cat':
- # leave list of colors the same way
- theme = colormap
-
- dim = len(dataframe)
- fig = make_subplots(rows=dim, cols=dim, print_grid=False)
- trace_list = []
- legend_param = 0
- # Work over all permutations of list pairs
- for listy in dataframe:
- for listx in dataframe:
- # create a dictionary for index_vals
- unique_index_vals = {}
- for name in index_vals:
- if name not in unique_index_vals:
- unique_index_vals[name] = []
-
- c_indx = 0 # color index
- # Fill all the rest of the names into the dictionary
- for name in sorted(unique_index_vals.keys()):
- new_listx = []
- new_listy = []
- for j in range(len(index_vals)):
- if index_vals[j] == name:
- new_listx.append(listx[j])
- new_listy.append(listy[j])
- # Generate trace with VISIBLE icon
- if legend_param == 1:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=new_listx,
- marker=dict(
- color=theme[c_indx]),
- showlegend=True
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=new_listx,
- name=None,
- marker=dict(
- color=theme[c_indx]),
- showlegend=True
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- kwargs['marker']['color'] = theme[c_indx]
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=name,
- showlegend=True,
- **kwargs
- )
- else:
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=name,
- marker=dict(
- size=size,
- color=theme[c_indx]),
- showlegend=True,
- **kwargs
- )
- # Generate trace with INVISIBLE icon
- else:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=new_listx,
- marker=dict(
- color=theme[c_indx]),
- showlegend=False
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=new_listx,
- name=None,
- marker=dict(
- color=theme[c_indx]),
- showlegend=False
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- kwargs['marker']['color'] = theme[c_indx]
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=name,
- showlegend=False,
- **kwargs
- )
- else:
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=name,
- marker=dict(
- size=size,
- color=theme[c_indx]),
- showlegend=False,
- **kwargs
- )
- # Push the trace into dictionary
- unique_index_vals[name] = trace
- if c_indx >= (len(theme) - 1):
- c_indx = -1
- c_indx += 1
- trace_list.append(unique_index_vals)
- legend_param += 1
-
- trace_index = 0
- indices = range(1, dim + 1)
- for y_index in indices:
- for x_index in indices:
- for name in sorted(trace_list[trace_index].keys()):
- fig.append_trace(
- trace_list[trace_index][name],
- y_index,
- x_index)
- trace_index += 1
-
- # Insert headers into the figure
- for j in range(dim):
- xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
- fig['layout'][xaxis_key].update(title=headers[j])
-
- for j in range(dim):
- yaxis_key = 'yaxis{}'.format(1 + (dim * j))
- fig['layout'][yaxis_key].update(title=headers[j])
-
- FigureFactory._hide_tick_labels_from_box_subplots(fig)
-
- if diag == 'histogram':
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True,
- barmode='stack')
- return fig
-
- elif diag == 'box':
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True)
- return fig
-
- else:
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True)
- return fig
-
- else:
- if endpts:
- intervals = FigureFactory._endpts_to_intervals(endpts)
-
- # Convert colormap to list of n RGB tuples
- if colormap_type == 'seq':
- foo = FigureFactory._color_parser(
- colormap, FigureFactory._unlabel_rgb
- )
- foo = FigureFactory._n_colors(foo[0],
- foo[1],
- len(intervals))
- theme = FigureFactory._color_parser(
- foo, FigureFactory._label_rgb
- )
-
- if colormap_type == 'cat':
- # leave list of colors the same way
- theme = colormap
-
- dim = len(dataframe)
- fig = make_subplots(rows=dim, cols=dim, print_grid=False)
- trace_list = []
- legend_param = 0
- # Work over all permutations of list pairs
- for listy in dataframe:
- for listx in dataframe:
- interval_labels = {}
- for interval in intervals:
- interval_labels[str(interval)] = []
-
- c_indx = 0 # color index
- # Fill all the rest of the names into the dictionary
- for interval in intervals:
- new_listx = []
- new_listy = []
- for j in range(len(index_vals)):
- if interval[0] < index_vals[j] <= interval[1]:
- new_listx.append(listx[j])
- new_listy.append(listy[j])
- # Generate trace with VISIBLE icon
- if legend_param == 1:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=new_listx,
- marker=dict(
- color=theme[c_indx]),
- showlegend=True
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=new_listx,
- name=None,
- marker=dict(
- color=theme[c_indx]),
- showlegend=True
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- (kwargs['marker']
- ['color']) = theme[c_indx]
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=str(interval),
- showlegend=True,
- **kwargs
- )
- else:
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=str(interval),
- marker=dict(
- size=size,
- color=theme[c_indx]),
- showlegend=True,
- **kwargs
- )
- # Generate trace with INVISIBLE icon
- else:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=new_listx,
- marker=dict(
- color=theme[c_indx]),
- showlegend=False
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=new_listx,
- name=None,
- marker=dict(
- color=theme[c_indx]),
- showlegend=False
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- (kwargs['marker']
- ['color']) = theme[c_indx]
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=str(interval),
- showlegend=False,
- **kwargs
- )
- else:
- trace = graph_objs.Scatter(
- x=new_listx,
- y=new_listy,
- mode='markers',
- name=str(interval),
- marker=dict(
- size=size,
- color=theme[c_indx]),
- showlegend=False,
- **kwargs
- )
- # Push the trace into dictionary
- interval_labels[str(interval)] = trace
- if c_indx >= (len(theme) - 1):
- c_indx = -1
- c_indx += 1
- trace_list.append(interval_labels)
- legend_param += 1
-
- trace_index = 0
- indices = range(1, dim + 1)
- for y_index in indices:
- for x_index in indices:
- for interval in intervals:
- fig.append_trace(
- trace_list[trace_index][str(interval)],
- y_index,
- x_index)
- trace_index += 1
-
- # Insert headers into the figure
- for j in range(dim):
- xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
- fig['layout'][xaxis_key].update(title=headers[j])
- for j in range(dim):
- yaxis_key = 'yaxis{}'.format(1 + (dim * j))
- fig['layout'][yaxis_key].update(title=headers[j])
-
- FigureFactory._hide_tick_labels_from_box_subplots(fig)
-
- if diag == 'histogram':
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True,
- barmode='stack')
- return fig
-
- elif diag == 'box':
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True)
- return fig
-
- else:
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True)
- return fig
-
- else:
- theme = colormap
-
- # add a copy of rgb color to theme if it contains one color
- if len(theme) <= 1:
- theme.append(theme[0])
-
- color = []
- for incr in range(len(theme)):
- color.append([1./(len(theme)-1)*incr, theme[incr]])
-
- dim = len(dataframe)
- fig = make_subplots(rows=dim, cols=dim, print_grid=False)
- trace_list = []
- legend_param = 0
- # Run through all permutations of list pairs
- for listy in dataframe:
- for listx in dataframe:
- # Generate trace with VISIBLE icon
- if legend_param == 1:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=listx,
- marker=dict(
- color=theme[0]),
- showlegend=False
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=listx,
- marker=dict(
- color=theme[0]),
- showlegend=False
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- kwargs['marker']['color'] = index_vals
- kwargs['marker']['colorscale'] = color
- kwargs['marker']['showscale'] = True
- trace = graph_objs.Scatter(
- x=listx,
- y=listy,
- mode='markers',
- showlegend=False,
- **kwargs
- )
- else:
- trace = graph_objs.Scatter(
- x=listx,
- y=listy,
- mode='markers',
- marker=dict(
- size=size,
- color=index_vals,
- colorscale=color,
- showscale=True),
- showlegend=False,
- **kwargs
- )
- # Generate trace with INVISIBLE icon
- else:
- if (listx == listy) and (diag == 'histogram'):
- trace = graph_objs.Histogram(
- x=listx,
- marker=dict(
- color=theme[0]),
- showlegend=False
- )
- elif (listx == listy) and (diag == 'box'):
- trace = graph_objs.Box(
- y=listx,
- marker=dict(
- color=theme[0]),
- showlegend=False
- )
- else:
- if 'marker' in kwargs:
- kwargs['marker']['size'] = size
- kwargs['marker']['color'] = index_vals
- kwargs['marker']['colorscale'] = color
- kwargs['marker']['showscale'] = False
- trace = graph_objs.Scatter(
- x=listx,
- y=listy,
- mode='markers',
- showlegend=False,
- **kwargs
- )
- else:
- trace = graph_objs.Scatter(
- x=listx,
- y=listy,
- mode='markers',
- marker=dict(
- size=size,
- color=index_vals,
- colorscale=color,
- showscale=False),
- showlegend=False,
- **kwargs
- )
- # Push the trace into list
- trace_list.append(trace)
- legend_param += 1
-
- trace_index = 0
- indices = range(1, dim + 1)
- for y_index in indices:
- for x_index in indices:
- fig.append_trace(trace_list[trace_index],
- y_index,
- x_index)
- trace_index += 1
-
- # Insert headers into the figure
- for j in range(dim):
- xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
- fig['layout'][xaxis_key].update(title=headers[j])
- for j in range(dim):
- yaxis_key = 'yaxis{}'.format(1 + (dim * j))
- fig['layout'][yaxis_key].update(title=headers[j])
-
- FigureFactory._hide_tick_labels_from_box_subplots(fig)
-
- if diag == 'histogram':
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True,
- barmode='stack')
- return fig
-
- elif diag == 'box':
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True)
- return fig
-
- else:
- fig['layout'].update(
- height=height, width=width,
- title=title,
- showlegend=True)
- return fig
-
- @staticmethod
- def _hide_tick_labels_from_box_subplots(fig):
- """
- Hides tick labels for box plots in scatterplotmatrix subplots.
- """
- boxplot_xaxes = []
- for trace in fig['data']:
- if trace['type'] == 'box':
- # stores the xaxes which correspond to boxplot subplots
- # since we use xaxis1, xaxis2, etc, in plotly.py
- boxplot_xaxes.append(
- 'xaxis{}'.format(trace['xaxis'][1:])
- )
- for xaxis in boxplot_xaxes:
- fig['layout'][xaxis]['showticklabels'] = False
-
- @staticmethod
- def _validate_index(index_vals):
- """
- Validates if a list contains all numbers or all strings
-
- :raises: (PlotlyError) If there are any two items in the list whose
- types differ
- """
- from numbers import Number
- if isinstance(index_vals[0], Number):
- if not all(isinstance(item, Number) for item in index_vals):
- raise exceptions.PlotlyError("Error in indexing column. "
- "Make sure all entries of each "
- "column are all numbers or "
- "all strings.")
-
- elif isinstance(index_vals[0], str):
- if not all(isinstance(item, str) for item in index_vals):
- raise exceptions.PlotlyError("Error in indexing column. "
- "Make sure all entries of each "
- "column are all numbers or "
- "all strings.")
-
- @staticmethod
- def _validate_dataframe(array):
- """
- Validates all strings or numbers in each dataframe column
-
- :raises: (PlotlyError) If there are any two items in any list whose
- types differ
- """
- from numbers import Number
- for vector in array:
- if isinstance(vector[0], Number):
- if not all(isinstance(item, Number) for item in vector):
- raise exceptions.PlotlyError("Error in dataframe. "
- "Make sure all entries of "
- "each column are either "
- "numbers or strings.")
- elif isinstance(vector[0], str):
- if not all(isinstance(item, str) for item in vector):
- raise exceptions.PlotlyError("Error in dataframe. "
- "Make sure all entries of "
- "each column are either "
- "numbers or strings.")
-
- @staticmethod
- def _validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs):
- """
- Validates basic inputs for FigureFactory.create_scatterplotmatrix()
-
- :raises: (PlotlyError) If pandas is not imported
- :raises: (PlotlyError) If pandas dataframe is not inputted
- :raises: (PlotlyError) If pandas dataframe has <= 1 columns
- :raises: (PlotlyError) If diagonal plot choice (diag) is not one of
- the viable options
- :raises: (PlotlyError) If colormap_type is not a valid choice
- :raises: (PlotlyError) If kwargs contains 'size', 'color' or
- 'colorscale'
- """
- if _pandas_imported is False:
- raise ImportError("FigureFactory.scatterplotmatrix requires "
- "a pandas DataFrame.")
-
- # Check if pandas dataframe
- if not isinstance(df, pd.core.frame.DataFrame):
- raise exceptions.PlotlyError("Dataframe not inputed. Please "
- "use a pandas dataframe to pro"
- "duce a scatterplot matrix.")
-
- # Check if dataframe is 1 column or less
- if len(df.columns) <= 1:
- raise exceptions.PlotlyError("Dataframe has only one column. To "
- "use the scatterplot matrix, use at "
- "least 2 columns.")
-
- # Check that diag parameter is a valid selection
- if diag not in DIAG_CHOICES:
- raise exceptions.PlotlyError("Make sure diag is set to "
- "one of {}".format(DIAG_CHOICES))
-
- # Check that colormap_types is a valid selection
- if colormap_type not in VALID_COLORMAP_TYPES:
- raise exceptions.PlotlyError("Must choose a valid colormap type. "
- "Either 'cat' or 'seq' for a cate"
- "gorical and sequential colormap "
- "respectively.")
-
- # Check for not 'size' or 'color' in 'marker' of **kwargs
- if 'marker' in kwargs:
- FORBIDDEN_PARAMS = ['size', 'color', 'colorscale']
- if any(param in kwargs['marker'] for param in FORBIDDEN_PARAMS):
- raise exceptions.PlotlyError("Your kwargs dictionary cannot "
- "include the 'size', 'color' or "
- "'colorscale' key words inside "
- "the marker dict since 'size' is "
- "already an argument of the "
- "scatterplot matrix function and "
- "both 'color' and 'colorscale "
- "are set internally.")
-
- @staticmethod
- def _endpts_to_intervals(endpts):
- """
- Returns a list of intervals for categorical colormaps
-
- Accepts a list or tuple of sequentially increasing numbers and returns
- a list representation of the mathematical intervals with these numbers
- as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]
-
- :raises: (PlotlyError) If input is not a list or tuple
- :raises: (PlotlyError) If the input contains a string
- :raises: (PlotlyError) If any number does not increase after the
- previous one in the sequence
- """
- length = len(endpts)
- # Check if endpts is a list or tuple
- if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))):
- raise exceptions.PlotlyError("The intervals_endpts argument must "
- "be a list or tuple of a sequence "
- "of increasing numbers.")
- # Check if endpts contains only numbers
- for item in endpts:
- if isinstance(item, str):
- raise exceptions.PlotlyError("The intervals_endpts argument "
- "must be a list or tuple of a "
- "sequence of increasing "
- "numbers.")
- # Check if numbers in endpts are increasing
- for k in range(length-1):
- if endpts[k] >= endpts[k+1]:
- raise exceptions.PlotlyError("The intervals_endpts argument "
- "must be a list or tuple of a "
- "sequence of increasing "
- "numbers.")
- else:
- intervals = []
- # add -inf to intervals
- intervals.append([float('-inf'), endpts[0]])
- for k in range(length - 1):
- interval = []
- interval.append(endpts[k])
- interval.append(endpts[k + 1])
- intervals.append(interval)
- # add +inf to intervals
- intervals.append([endpts[length - 1], float('inf')])
- return intervals
-
- @staticmethod
- def _convert_to_RGB_255(colors):
- """
- Multiplies each element of a triplet by 255
-
- Each coordinate of the color tuple is rounded to the nearest float and
- then is turned into an integer. If a number is of the form x.5, then
- if x is odd, the number rounds up to (x+1). Otherwise, it rounds down
- to just x. This is the way rounding works in Python 3 and in current
- statistical analysis to avoid rounding bias
- """
- rgb_components = []
-
- for component in colors:
- rounded_num = decimal.Decimal(str(component*255.0)).quantize(
- decimal.Decimal('1'), rounding=decimal.ROUND_HALF_EVEN
- )
- # convert rounded number to an integer from 'Decimal' form
- rounded_num = int(rounded_num)
- rgb_components.append(rounded_num)
-
- return (rgb_components[0], rgb_components[1], rgb_components[2])
-
- @staticmethod
- def _n_colors(lowcolor, highcolor, n_colors):
- """
- Splits a low and high color into a list of n_colors colors in it
-
- Accepts two color tuples and returns a list of n_colors colors
- which form the intermediate colors between lowcolor and highcolor
- from linearly interpolating through RGB space
-
- """
- diff_0 = float(highcolor[0] - lowcolor[0])
- incr_0 = diff_0/(n_colors - 1)
- diff_1 = float(highcolor[1] - lowcolor[1])
- incr_1 = diff_1/(n_colors - 1)
- diff_2 = float(highcolor[2] - lowcolor[2])
- incr_2 = diff_2/(n_colors - 1)
- color_tuples = []
-
- for index in range(n_colors):
- new_tuple = (lowcolor[0] + (index * incr_0),
- lowcolor[1] + (index * incr_1),
- lowcolor[2] + (index * incr_2))
- color_tuples.append(new_tuple)
-
- return color_tuples
-
- @staticmethod
- def _label_rgb(colors):
- """
- Takes tuple (a, b, c) and returns an rgb color 'rgb(a, b, c)'
- """
- return ('rgb(%s, %s, %s)' % (colors[0], colors[1], colors[2]))
-
- @staticmethod
- def _unlabel_rgb(colors):
- """
- Takes rgb color(s) 'rgb(a, b, c)' and returns tuple(s) (a, b, c)
-
- This function takes either an 'rgb(a, b, c)' color or a list of
- such colors and returns the color tuples in tuple(s) (a, b, c)
-
- """
- str_vals = ''
- for index in range(len(colors)):
- try:
- float(colors[index])
- str_vals = str_vals + colors[index]
- except ValueError:
- if colors[index] == ',' or colors[index] == '.':
- str_vals = str_vals + colors[index]
-
- str_vals = str_vals + ','
- numbers = []
- str_num = ''
- for char in str_vals:
- if char != ',':
- str_num = str_num + char
- else:
- numbers.append(float(str_num))
- str_num = ''
- return (numbers[0], numbers[1], numbers[2])
-
- @staticmethod
- def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter',
- height=500, width=500, size=6,
- title='Scatterplot Matrix', colormap=None,
- colormap_type='cat', dataframe=None,
- headers=None, index_vals=None, **kwargs):
- """
- Returns data for a scatterplot matrix.
-
- :param (array) df: array of the data with column headers
- :param (str) index: name of the index column in data array
- :param (list|tuple) endpts: takes an increasing sequece of numbers
- that defines intervals on the real line. They are used to group
- the entries in an index of numbers into their corresponding
- interval and therefore can be treated as categorical data
- :param (str) diag: sets the chart type for the main diagonal plots.
- The options are 'scatter', 'histogram' and 'box'.
- :param (int|float) height: sets the height of the chart
- :param (int|float) width: sets the width of the chart
- :param (float) size: sets the marker size (in px)
- :param (str) title: the title label of the scatterplot matrix
- :param (str|tuple|list|dict) colormap: either a plotly scale name,
- an rgb or hex color, a color tuple, a list of colors or a
- dictionary. An rgb color is of the form 'rgb(x, y, z)' where
- x, y and z belong to the interval [0, 255] and a color tuple is a
- tuple of the form (a, b, c) where a, b and c belong to [0, 1].
- If colormap is a list, it must contain valid color types as its
- members.
- If colormap is a dictionary, all the string entries in
- the index column must be a key in colormap. In this case, the
- colormap_type is forced to 'cat' or categorical
- :param (str) colormap_type: determines how colormap is interpreted.
- Valid choices are 'seq' (sequential) and 'cat' (categorical). If
- 'seq' is selected, only the first two colors in colormap will be
- considered (when colormap is a list) and the index values will be
- linearly interpolated between those two colors. This option is
- forced if all index values are numeric.
- If 'cat' is selected, a color from colormap will be assigned to
- each category from index, including the intervals if endpts is
- being used
- :param (dict) **kwargs: a dictionary of scatterplot arguments
- The only forbidden parameters are 'size', 'color' and
- 'colorscale' in 'marker'
-
- Example 1: Vanilla Scatterplot Matrix
- ```
- import plotly.plotly as py
- from plotly.graph_objs import graph_objs
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import pandas as pd
-
- # Create dataframe
- df = pd.DataFrame(np.random.randn(10, 2),
- columns=['Column 1', 'Column 2'])
-
- # Create scatterplot matrix
- fig = FF.create_scatterplotmatrix(df)
-
- # Plot
- py.iplot(fig, filename='Vanilla Scatterplot Matrix')
- ```
-
- Example 2: Indexing a Column
- ```
- import plotly.plotly as py
- from plotly.graph_objs import graph_objs
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import pandas as pd
-
- # Create dataframe with index
- df = pd.DataFrame(np.random.randn(10, 2),
- columns=['A', 'B'])
-
- # Add another column of strings to the dataframe
- df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
- 'grape', 'pear', 'pear', 'apple', 'pear'])
-
- # Create scatterplot matrix
- fig = FF.create_scatterplotmatrix(df, index='Fruit', size=10)
-
- # Plot
- py.iplot(fig, filename = 'Scatterplot Matrix with Index')
- ```
-
- Example 3: Styling the Diagonal Subplots
- ```
- import plotly.plotly as py
- from plotly.graph_objs import graph_objs
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import pandas as pd
-
- # Create dataframe with index
- df = pd.DataFrame(np.random.randn(10, 4),
- columns=['A', 'B', 'C', 'D'])
-
- # Add another column of strings to the dataframe
- df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
- 'grape', 'pear', 'pear', 'apple', 'pear'])
-
- # Create scatterplot matrix
- fig = FF.create_scatterplotmatrix(df, diag='box', index='Fruit',
- height=1000, width=1000)
-
- # Plot
- py.iplot(fig, filename = 'Scatterplot Matrix - Diagonal Styling')
- ```
-
- Example 4: Use a Theme to Style the Subplots
- ```
- import plotly.plotly as py
- from plotly.graph_objs import graph_objs
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import pandas as pd
-
- # Create dataframe with random data
- df = pd.DataFrame(np.random.randn(100, 3),
- columns=['A', 'B', 'C'])
-
- # Create scatterplot matrix using a built-in
- # Plotly palette scale and indexing column 'A'
- fig = FF.create_scatterplotmatrix(df, diag='histogram',
- index='A', colormap='Blues',
- height=800, width=800)
-
- # Plot
- py.iplot(fig, filename = 'Scatterplot Matrix - Colormap Theme')
- ```
-
- Example 5: Example 4 with Interval Factoring
- ```
- import plotly.plotly as py
- from plotly.graph_objs import graph_objs
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import pandas as pd
-
- # Create dataframe with random data
- df = pd.DataFrame(np.random.randn(100, 3),
- columns=['A', 'B', 'C'])
-
- # Create scatterplot matrix using a list of 2 rgb tuples
- # and endpoints at -1, 0 and 1
- fig = FF.create_scatterplotmatrix(df, diag='histogram', index='A',
- colormap=['rgb(140, 255, 50)',
- 'rgb(170, 60, 115)',
- '#6c4774',
- (0.5, 0.1, 0.8)],
- endpts=[-1, 0, 1],
- height=800, width=800)
-
- # Plot
- py.iplot(fig, filename = 'Scatterplot Matrix - Intervals')
- ```
-
- Example 6: Using the colormap as a Dictionary
- ```
- import plotly.plotly as py
- from plotly.graph_objs import graph_objs
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import pandas as pd
- import random
-
- # Create dataframe with random data
- df = pd.DataFrame(np.random.randn(100, 3),
- columns=['Column A',
- 'Column B',
- 'Column C'])
-
- # Add new color column to dataframe
- new_column = []
- strange_colors = ['turquoise', 'limegreen', 'goldenrod']
-
- for j in range(100):
- new_column.append(random.choice(strange_colors))
- df['Colors'] = pd.Series(new_column, index=df.index)
-
- # Create scatterplot matrix using a dictionary of hex color values
- # which correspond to actual color names in 'Colors' column
- fig = FF.create_scatterplotmatrix(
- df, diag='box', index='Colors',
- colormap= dict(
- turquoise = '#00F5FF',
- limegreen = '#32CD32',
- goldenrod = '#DAA520'
- ),
- colormap_type='cat',
- height=800, width=800
- )
-
- # Plot
- py.iplot(fig, filename = 'Scatterplot Matrix - colormap dictionary ')
- ```
- """
- # TODO: protected until #282
- if dataframe is None:
- dataframe = []
- if headers is None:
- headers = []
- if index_vals is None:
- index_vals = []
-
- FigureFactory._validate_scatterplotmatrix(df, index, diag,
- colormap_type, **kwargs)
-
- # Validate colormap
- if isinstance(colormap, dict):
- colormap = FigureFactory._validate_colors_dict(colormap, 'rgb')
- else:
- colormap = FigureFactory._validate_colors(colormap, 'rgb')
-
- if not index:
- for name in df:
- headers.append(name)
- for name in headers:
- dataframe.append(df[name].values.tolist())
- # Check for same data-type in df columns
- FigureFactory._validate_dataframe(dataframe)
- figure = FigureFactory._scatterplot(dataframe, headers, diag,
- size, height, width, title,
- **kwargs)
- return figure
- else:
- # Validate index selection
- if index not in df:
- raise exceptions.PlotlyError("Make sure you set the index "
- "input variable to one of the "
- "column names of your "
- "dataframe.")
- index_vals = df[index].values.tolist()
- for name in df:
- if name != index:
- headers.append(name)
- for name in headers:
- dataframe.append(df[name].values.tolist())
-
- # check for same data-type in each df column
- FigureFactory._validate_dataframe(dataframe)
- FigureFactory._validate_index(index_vals)
-
- # check if all colormap keys are in the index
- # if colormap is a dictionary
- if isinstance(colormap, dict):
- for key in colormap:
- if not all(index in colormap for index in index_vals):
- raise exceptions.PlotlyError("If colormap is a "
- "dictionary, all the "
- "names in the index "
- "must be keys.")
- figure = FigureFactory._scatterplot_dict(
- dataframe, headers, diag, size, height, width, title,
- index, index_vals, endpts, colormap, colormap_type,
- **kwargs
- )
- return figure
-
- else:
- figure = FigureFactory._scatterplot_theme(
- dataframe, headers, diag, size, height, width, title,
- index, index_vals, endpts, colormap, colormap_type,
- **kwargs
- )
- return figure
-
- @staticmethod
- def _validate_equal_length(*args):
- """
- Validates that data lists or ndarrays are the same length.
-
- :raises: (PlotlyError) If any data lists are not the same length.
- """
- length = len(args[0])
- if any(len(lst) != length for lst in args):
- raise exceptions.PlotlyError("Oops! Your data lists or ndarrays "
- "should be the same length.")
-
- @staticmethod
- def _validate_ohlc(open, high, low, close, direction, **kwargs):
- """
- ohlc and candlestick specific validations
-
- Specifically, this checks that the high value is the greatest value and
- the low value is the lowest value in each unit.
-
- See FigureFactory.create_ohlc() or FigureFactory.create_candlestick()
- for params
-
- :raises: (PlotlyError) If the high value is not the greatest value in
- each unit.
- :raises: (PlotlyError) If the low value is not the lowest value in each
- unit.
- :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing'
- """
- for lst in [open, low, close]:
- for index in range(len(high)):
- if high[index] < lst[index]:
- raise exceptions.PlotlyError("Oops! Looks like some of "
- "your high values are less "
- "the corresponding open, "
- "low, or close values. "
- "Double check that your data "
- "is entered in O-H-L-C order")
-
- for lst in [open, high, close]:
- for index in range(len(low)):
- if low[index] > lst[index]:
- raise exceptions.PlotlyError("Oops! Looks like some of "
- "your low values are greater "
- "than the corresponding high"
- ", open, or close values. "
- "Double check that your data "
- "is entered in O-H-L-C order")
-
- direction_opts = ('increasing', 'decreasing', 'both')
- if direction not in direction_opts:
- raise exceptions.PlotlyError("direction must be defined as "
- "'increasing', 'decreasing', or "
- "'both'")
-
- @staticmethod
- def _validate_distplot(hist_data, curve_type):
- """
- Distplot-specific validations
-
- :raises: (PlotlyError) If hist_data is not a list of lists
- :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
- 'normal').
- """
- try:
- import pandas as pd
- _pandas_imported = True
- except ImportError:
- _pandas_imported = False
-
- hist_data_types = (list,)
- if _numpy_imported:
- hist_data_types += (np.ndarray,)
- if _pandas_imported:
- hist_data_types += (pd.core.series.Series,)
-
- if not isinstance(hist_data[0], hist_data_types):
- raise exceptions.PlotlyError("Oops, this function was written "
- "to handle multiple datasets, if "
- "you want to plot just one, make "
- "sure your hist_data variable is "
- "still a list of lists, i.e. x = "
- "[1, 2, 3] -> x = [[1, 2, 3]]")
-
- curve_opts = ('kde', 'normal')
- if curve_type not in curve_opts:
- raise exceptions.PlotlyError("curve_type must be defined as "
- "'kde' or 'normal'")
-
- if _scipy_imported is False:
- raise ImportError("FigureFactory.create_distplot requires scipy")
-
- @staticmethod
- def _validate_positive_scalars(**kwargs):
- """
- Validates that all values given in key/val pairs are positive.
-
- Accepts kwargs to improve Exception messages.
-
- :raises: (PlotlyError) If any value is < 0 or raises.
- """
- for key, val in kwargs.items():
- try:
- if val <= 0:
- raise ValueError('{} must be > 0, got {}'.format(key, val))
- except TypeError:
- raise exceptions.PlotlyError('{} must be a number, got {}'
- .format(key, val))
-
- @staticmethod
- def _validate_streamline(x, y):
- """
- Streamline-specific validations
-
- Specifically, this checks that x and y are both evenly spaced,
- and that the package numpy is available.
-
- See FigureFactory.create_streamline() for params
-
- :raises: (ImportError) If numpy is not available.
- :raises: (PlotlyError) If x is not evenly spaced.
- :raises: (PlotlyError) If y is not evenly spaced.
- """
- if _numpy_imported is False:
- raise ImportError("FigureFactory.create_streamline requires numpy")
- for index in range(len(x) - 1):
- if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001:
- raise exceptions.PlotlyError("x must be a 1 dimensional, "
- "evenly spaced array")
- for index in range(len(y) - 1):
- if ((y[index + 1] - y[index]) -
- (y[1] - y[0])) > .0001:
- raise exceptions.PlotlyError("y must be a 1 dimensional, "
- "evenly spaced array")
-
- @staticmethod
- def _validate_annotated_heatmap(z, x, y, annotation_text):
- """
- Annotated-heatmap-specific validations
-
- Check that if a text matrix is supplied, it has the same
- dimensions as the z matrix.
-
- See FigureFactory.create_annotated_heatmap() for params
-
- :raises: (PlotlyError) If z and text matrices do not have the same
- dimensions.
- """
- if annotation_text is not None and isinstance(annotation_text, list):
- FigureFactory._validate_equal_length(z, annotation_text)
- for lst in range(len(z)):
- if len(z[lst]) != len(annotation_text[lst]):
- raise exceptions.PlotlyError("z and text should have the "
- "same dimensions")
-
- if x:
- if len(x) != len(z[0]):
- raise exceptions.PlotlyError("oops, the x list that you "
- "provided does not match the "
- "width of your z matrix ")
-
- if y:
- if len(y) != len(z):
- raise exceptions.PlotlyError("oops, the y list that you "
- "provided does not match the "
- "length of your z matrix ")
-
- @staticmethod
- def _validate_table(table_text, font_colors):
- """
- Table-specific validations
-
- Check that font_colors is supplied correctly (1, 3, or len(text)
- colors).
-
- :raises: (PlotlyError) If font_colors is supplied incorretly.
-
- See FigureFactory.create_table() for params
- """
- font_colors_len_options = [1, 3, len(table_text)]
- if len(font_colors) not in font_colors_len_options:
- raise exceptions.PlotlyError("Oops, font_colors should be a list "
- "of length 1, 3 or len(text)")
-
- @staticmethod
- def _flatten(array):
- """
- Uses list comprehension to flatten array
-
- :param (array): An iterable to flatten
- :raises (PlotlyError): If iterable is not nested.
- :rtype (list): The flattened list.
- """
- try:
- return [item for sublist in array for item in sublist]
- except TypeError:
- raise exceptions.PlotlyError("Your data array could not be "
- "flattened! Make sure your data is "
- "entered as lists or ndarrays!")
-
- @staticmethod
- def _hex_to_rgb(value):
- """
- Calculates rgb values from a hex color code.
-
- :param (string) value: Hex color string
-
- :rtype (tuple) (r_value, g_value, b_value): tuple of rgb values
- """
- value = value.lstrip('#')
- hex_total_length = len(value)
- rgb_section_length = hex_total_length // 3
- return tuple(int(value[i:i + rgb_section_length], 16)
- for i in range(0, hex_total_length, rgb_section_length))
-
- @staticmethod
- def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3,
- angle=math.pi / 9, **kwargs):
- """
- Returns data for a quiver plot.
-
- :param (list|ndarray) x: x coordinates of the arrow locations
- :param (list|ndarray) y: y coordinates of the arrow locations
- :param (list|ndarray) u: x components of the arrow vectors
- :param (list|ndarray) v: y components of the arrow vectors
- :param (float in [0,1]) scale: scales size of the arrows(ideally to
- avoid overlap). Default = .1
- :param (float in [0,1]) arrow_scale: value multiplied to length of barb
- to get length of arrowhead. Default = .3
- :param (angle in radians) angle: angle of arrowhead. Default = pi/9
- :param kwargs: kwargs passed through plotly.graph_objs.Scatter
- for more information on valid kwargs call
- help(plotly.graph_objs.Scatter)
-
- :rtype (dict): returns a representation of quiver figure.
-
- Example 1: Trivial Quiver
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import math
-
- # 1 Arrow from (0,0) to (1,1)
- fig = FF.create_quiver(x=[0], y=[0],
- u=[1], v=[1],
- scale=1)
-
- py.plot(fig, filename='quiver')
- ```
-
- Example 2: Quiver plot using meshgrid
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import math
-
- # Add data
- x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
- u = np.cos(x)*y
- v = np.sin(x)*y
-
- #Create quiver
- fig = FF.create_quiver(x, y, u, v)
-
- # Plot
- py.plot(fig, filename='quiver')
- ```
-
- Example 3: Styling the quiver plot
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- import numpy as np
- import math
-
- # Add data
- x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5),
- np.arange(-math.pi, math.pi, .5))
- u = np.cos(x)*y
- v = np.sin(x)*y
-
- # Create quiver
- fig = FF.create_quiver(x, y, u, v, scale=.2,
- arrow_scale=.3,
- angle=math.pi/6,
- name='Wind Velocity',
- line=Line(width=1))
-
- # Add title to layout
- fig['layout'].update(title='Quiver Plot')
-
- # Plot
- py.plot(fig, filename='quiver')
- ```
- """
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
- FigureFactory._validate_equal_length(x, y, u, v)
- FigureFactory._validate_positive_scalars(arrow_scale=arrow_scale,
- scale=scale)
-
- barb_x, barb_y = _Quiver(x, y, u, v, scale,
- arrow_scale, angle).get_barbs()
- arrow_x, arrow_y = _Quiver(x, y, u, v, scale,
- arrow_scale, angle).get_quiver_arrows()
- quiver = graph_objs.Scatter(x=barb_x + arrow_x,
- y=barb_y + arrow_y,
- mode='lines', **kwargs)
-
- data = [quiver]
- layout = graph_objs.Layout(hovermode='closest')
-
- return graph_objs.Figure(data=data, layout=layout)
-
- @staticmethod
- def create_streamline(x, y, u, v,
- density=1, angle=math.pi / 9,
- arrow_scale=.09, **kwargs):
- """
- Returns data for a streamline plot.
-
- :param (list|ndarray) x: 1 dimensional, evenly spaced list or array
- :param (list|ndarray) y: 1 dimensional, evenly spaced list or array
- :param (ndarray) u: 2 dimensional array
- :param (ndarray) v: 2 dimensional array
- :param (float|int) density: controls the density of streamlines in
- plot. This is multiplied by 30 to scale similiarly to other
- available streamline functions such as matplotlib.
- Default = 1
- :param (angle in radians) angle: angle of arrowhead. Default = pi/9
- :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
- Default = .09
- :param kwargs: kwargs passed through plotly.graph_objs.Scatter
- for more information on valid kwargs call
- help(plotly.graph_objs.Scatter)
-
- :rtype (dict): returns a representation of streamline figure.
-
- Example 1: Plot simple streamline and increase arrow size
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import math
-
- # Add data
- x = np.linspace(-3, 3, 100)
- y = np.linspace(-3, 3, 100)
- Y, X = np.meshgrid(x, y)
- u = -1 - X**2 + Y
- v = 1 + X - Y**2
- u = u.T # Transpose
- v = v.T # Transpose
-
- # Create streamline
- fig = FF.create_streamline(x, y, u, v,
- arrow_scale=.1)
-
- # Plot
- py.plot(fig, filename='streamline')
- ```
-
- Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import math
-
- # Add data
- N = 50
- x_start, x_end = -2.0, 2.0
- y_start, y_end = -1.0, 1.0
- x = np.linspace(x_start, x_end, N)
- y = np.linspace(y_start, y_end, N)
- X, Y = np.meshgrid(x, y)
- ss = 5.0
- x_s, y_s = -1.0, 0.0
-
- # Compute the velocity field on the mesh grid
- u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2)
- v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2)
-
- # Create streamline
- fig = FF.create_streamline(x, y, u_s, v_s,
- density=2, name='streamline')
-
- # Add source point
- point = Scatter(x=[x_s], y=[y_s], mode='markers',
- marker=Marker(size=14), name='source point')
-
- # Plot
- fig['data'].append(point)
- py.plot(fig, filename='streamline')
- ```
- """
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
- FigureFactory._validate_equal_length(x, y)
- FigureFactory._validate_equal_length(u, v)
- FigureFactory._validate_streamline(x, y)
- FigureFactory._validate_positive_scalars(density=density,
- arrow_scale=arrow_scale)
-
- streamline_x, streamline_y = _Streamline(x, y, u, v,
- density, angle,
- arrow_scale).sum_streamlines()
- arrow_x, arrow_y = _Streamline(x, y, u, v,
- density, angle,
- arrow_scale).get_streamline_arrows()
-
- streamline = graph_objs.Scatter(x=streamline_x + arrow_x,
- y=streamline_y + arrow_y,
- mode='lines', **kwargs)
-
- data = [streamline]
- layout = graph_objs.Layout(hovermode='closest')
-
- return graph_objs.Figure(data=data, layout=layout)
-
- @staticmethod
- def _make_increasing_ohlc(open, high, low, close, dates, **kwargs):
- """
- Makes increasing ohlc sticks
-
- _make_increasing_ohlc() and _make_decreasing_ohlc separate the
- increasing trace from the decreasing trace so kwargs (such as
- color) can be passed separately to increasing or decreasing traces
- when direction is set to 'increasing' or 'decreasing' in
- FigureFactory.create_candlestick()
-
- :param (list) open: opening values
- :param (list) high: high values
- :param (list) low: low values
- :param (list) close: closing values
- :param (list) dates: list of datetime objects. Default: None
- :param kwargs: kwargs to be passed to increasing trace via
- plotly.graph_objs.Scatter.
-
- :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc
- sticks.
- """
- (flat_increase_x,
- flat_increase_y,
- text_increase) = _OHLC(open, high, low, close, dates).get_increase()
-
- if 'name' in kwargs:
- showlegend = True
- else:
- kwargs.setdefault('name', 'Increasing')
- showlegend = False
-
- kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR,
- width=1))
- kwargs.setdefault('text', text_increase)
-
- ohlc_incr = dict(type='scatter',
- x=flat_increase_x,
- y=flat_increase_y,
- mode='lines',
- showlegend=showlegend,
- **kwargs)
- return ohlc_incr
-
- @staticmethod
- def _make_decreasing_ohlc(open, high, low, close, dates, **kwargs):
- """
- Makes decreasing ohlc sticks
-
- :param (list) open: opening values
- :param (list) high: high values
- :param (list) low: low values
- :param (list) close: closing values
- :param (list) dates: list of datetime objects. Default: None
- :param kwargs: kwargs to be passed to increasing trace via
- plotly.graph_objs.Scatter.
-
- :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc
- sticks.
- """
- (flat_decrease_x,
- flat_decrease_y,
- text_decrease) = _OHLC(open, high, low, close, dates).get_decrease()
-
- kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR,
- width=1))
- kwargs.setdefault('text', text_decrease)
- kwargs.setdefault('showlegend', False)
- kwargs.setdefault('name', 'Decreasing')
-
- ohlc_decr = dict(type='scatter',
- x=flat_decrease_x,
- y=flat_decrease_y,
- mode='lines',
- **kwargs)
- return ohlc_decr
-
- @staticmethod
- def create_ohlc(open, high, low, close,
- dates=None, direction='both',
- **kwargs):
- """
- BETA function that creates an ohlc chart
-
- :param (list) open: opening values
- :param (list) high: high values
- :param (list) low: low values
- :param (list) close: closing
- :param (list) dates: list of datetime objects. Default: None
- :param (string) direction: direction can be 'increasing', 'decreasing',
- or 'both'. When the direction is 'increasing', the returned figure
- consists of all units where the close value is greater than the
- corresponding open value, and when the direction is 'decreasing',
- the returned figure consists of all units where the close value is
- less than or equal to the corresponding open value. When the
- direction is 'both', both increasing and decreasing units are
- returned. Default: 'both'
- :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
- These kwargs describe other attributes about the ohlc Scatter trace
- such as the color or the legend name. For more information on valid
- kwargs call help(plotly.graph_objs.Scatter)
-
- :rtype (dict): returns a representation of an ohlc chart figure.
-
- Example 1: Simple OHLC chart from a Pandas DataFrame
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from datetime import datetime
-
- import pandas.io.data as web
-
- df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), datetime(2008, 10, 15))
- fig = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index)
-
- py.plot(fig, filename='finance/aapl-ohlc')
- ```
-
- Example 2: Add text and annotations to the OHLC chart
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from datetime import datetime
-
- import pandas.io.data as web
-
- df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), datetime(2008, 10, 15))
- fig = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index)
-
- # Update the fig - all options here: https://plot.ly/python/reference/#Layout
- fig['layout'].update({
- 'title': 'The Great Recession',
- 'yaxis': {'title': 'AAPL Stock'},
- 'shapes': [{
- 'x0': '2008-09-15', 'x1': '2008-09-15', 'type': 'line',
- 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
- 'line': {'color': 'rgb(40,40,40)', 'width': 0.5}
- }],
- 'annotations': [{
- 'text': "the fall of Lehman Brothers",
- 'x': '2008-09-15', 'y': 1.02,
- 'xref': 'x', 'yref': 'paper',
- 'showarrow': False, 'xanchor': 'left'
- }]
- })
-
- py.plot(fig, filename='finance/aapl-recession-ohlc', validate=False)
- ```
-
- Example 3: Customize the OHLC colors
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import Line, Marker
- from datetime import datetime
-
- import pandas.io.data as web
-
- df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1))
-
- # Make increasing ohlc sticks and customize their color and name
- fig_increasing = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index,
- direction='increasing', name='AAPL',
- line=Line(color='rgb(150, 200, 250)'))
-
- # Make decreasing ohlc sticks and customize their color and name
- fig_decreasing = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index,
- direction='decreasing',
- line=Line(color='rgb(128, 128, 128)'))
-
- # Initialize the figure
- fig = fig_increasing
-
- # Add decreasing data with .extend()
- fig['data'].extend(fig_decreasing['data'])
-
- py.iplot(fig, filename='finance/aapl-ohlc-colors', validate=False)
- ```
-
- Example 4: OHLC chart with datetime objects
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- from datetime import datetime
-
- # Add data
- open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
- high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
- low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
- close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
- dates = [datetime(year=2013, month=10, day=10),
- datetime(year=2013, month=11, day=10),
- datetime(year=2013, month=12, day=10),
- datetime(year=2014, month=1, day=10),
- datetime(year=2014, month=2, day=10)]
-
- # Create ohlc
- fig = FF.create_ohlc(open_data, high_data,
- low_data, close_data, dates=dates)
-
- py.iplot(fig, filename='finance/simple-ohlc', validate=False)
- ```
- """
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
- if dates is not None:
- FigureFactory._validate_equal_length(open, high, low, close, dates)
- else:
- FigureFactory._validate_equal_length(open, high, low, close)
- FigureFactory._validate_ohlc(open, high, low, close, direction,
- **kwargs)
-
- if direction is 'increasing':
- ohlc_incr = FigureFactory._make_increasing_ohlc(open, high,
- low, close,
- dates, **kwargs)
- data = [ohlc_incr]
- elif direction is 'decreasing':
- ohlc_decr = FigureFactory._make_decreasing_ohlc(open, high,
- low, close,
- dates, **kwargs)
- data = [ohlc_decr]
- else:
- ohlc_incr = FigureFactory._make_increasing_ohlc(open, high,
- low, close,
- dates, **kwargs)
- ohlc_decr = FigureFactory._make_decreasing_ohlc(open, high,
- low, close,
- dates, **kwargs)
- data = [ohlc_incr, ohlc_decr]
-
- layout = graph_objs.Layout(xaxis=dict(zeroline=False),
- hovermode='closest')
-
- return graph_objs.Figure(data=data, layout=layout)
-
- @staticmethod
- def _make_increasing_candle(open, high, low, close, dates, **kwargs):
- """
- Makes boxplot trace for increasing candlesticks
-
- _make_increasing_candle() and _make_decreasing_candle separate the
- increasing traces from the decreasing traces so kwargs (such as
- color) can be passed separately to increasing or decreasing traces
- when direction is set to 'increasing' or 'decreasing' in
- FigureFactory.create_candlestick()
-
- :param (list) open: opening values
- :param (list) high: high values
- :param (list) low: low values
- :param (list) close: closing values
- :param (list) dates: list of datetime objects. Default: None
- :param kwargs: kwargs to be passed to increasing trace via
- plotly.graph_objs.Scatter.
-
- :rtype (list) candle_incr_data: list of the box trace for
- increasing candlesticks.
- """
- increase_x, increase_y = _Candlestick(
- open, high, low, close, dates, **kwargs).get_candle_increase()
-
- if 'line' in kwargs:
- kwargs.setdefault('fillcolor', kwargs['line']['color'])
- else:
- kwargs.setdefault('fillcolor', _DEFAULT_INCREASING_COLOR)
- if 'name' in kwargs:
- kwargs.setdefault('showlegend', True)
- else:
- kwargs.setdefault('showlegend', False)
- kwargs.setdefault('name', 'Increasing')
- kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR))
-
- candle_incr_data = dict(type='box',
- x=increase_x,
- y=increase_y,
- whiskerwidth=0,
- boxpoints=False,
- **kwargs)
-
- return [candle_incr_data]
-
- @staticmethod
- def _make_decreasing_candle(open, high, low, close, dates, **kwargs):
- """
- Makes boxplot trace for decreasing candlesticks
-
- :param (list) open: opening values
- :param (list) high: high values
- :param (list) low: low values
- :param (list) close: closing values
- :param (list) dates: list of datetime objects. Default: None
- :param kwargs: kwargs to be passed to decreasing trace via
- plotly.graph_objs.Scatter.
-
- :rtype (list) candle_decr_data: list of the box trace for
- decreasing candlesticks.
- """
-
- decrease_x, decrease_y = _Candlestick(
- open, high, low, close, dates, **kwargs).get_candle_decrease()
-
- if 'line' in kwargs:
- kwargs.setdefault('fillcolor', kwargs['line']['color'])
- else:
- kwargs.setdefault('fillcolor', _DEFAULT_DECREASING_COLOR)
- kwargs.setdefault('showlegend', False)
- kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR))
- kwargs.setdefault('name', 'Decreasing')
-
- candle_decr_data = dict(type='box',
- x=decrease_x,
- y=decrease_y,
- whiskerwidth=0,
- boxpoints=False,
- **kwargs)
-
- return [candle_decr_data]
-
- @staticmethod
- def create_candlestick(open, high, low, close,
- dates=None, direction='both', **kwargs):
- """
- BETA function that creates a candlestick chart
-
- :param (list) open: opening values
- :param (list) high: high values
- :param (list) low: low values
- :param (list) close: closing values
- :param (list) dates: list of datetime objects. Default: None
- :param (string) direction: direction can be 'increasing', 'decreasing',
- or 'both'. When the direction is 'increasing', the returned figure
- consists of all candlesticks where the close value is greater than
- the corresponding open value, and when the direction is
- 'decreasing', the returned figure consists of all candlesticks
- where the close value is less than or equal to the corresponding
- open value. When the direction is 'both', both increasing and
- decreasing candlesticks are returned. Default: 'both'
- :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
- These kwargs describe other attributes about the ohlc Scatter trace
- such as the color or the legend name. For more information on valid
- kwargs call help(plotly.graph_objs.Scatter)
-
- :rtype (dict): returns a representation of candlestick chart figure.
-
- Example 1: Simple candlestick chart from a Pandas DataFrame
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from datetime import datetime
-
- import pandas.io.data as web
-
- df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1), datetime(2009, 4, 1))
- fig = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index)
- py.plot(fig, filename='finance/aapl-candlestick', validate=False)
- ```
-
- Example 2: Add text and annotations to the candlestick chart
- ```
- fig = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index)
- # Update the fig - all options here: https://plot.ly/python/reference/#Layout
- fig['layout'].update({
- 'title': 'The Great Recession',
- 'yaxis': {'title': 'AAPL Stock'},
- 'shapes': [{
- 'x0': '2007-12-01', 'x1': '2007-12-01',
- 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
- 'line': {'color': 'rgb(30,30,30)', 'width': 1}
- }],
- 'annotations': [{
- 'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper',
- 'showarrow': False, 'xanchor': 'left',
- 'text': 'Official start of the recession'
- }]
- })
- py.plot(fig, filename='finance/aapl-recession-candlestick', validate=False)
- ```
-
- Example 3: Customize the candlestick colors
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- from plotly.graph_objs import Line, Marker
- from datetime import datetime
-
- import pandas.io.data as web
-
- df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1))
-
- # Make increasing candlesticks and customize their color and name
- fig_increasing = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index,
- direction='increasing', name='AAPL',
- marker=Marker(color='rgb(150, 200, 250)'),
- line=Line(color='rgb(150, 200, 250)'))
-
- # Make decreasing candlesticks and customize their color and name
- fig_decreasing = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index,
- direction='decreasing',
- marker=Marker(color='rgb(128, 128, 128)'),
- line=Line(color='rgb(128, 128, 128)'))
-
- # Initialize the figure
- fig = fig_increasing
-
- # Add decreasing data with .extend()
- fig['data'].extend(fig_decreasing['data'])
-
- py.iplot(fig, filename='finance/aapl-candlestick-custom', validate=False)
- ```
-
- Example 4: Candlestick chart with datetime objects
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- from datetime import datetime
-
- # Add data
- open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
- high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
- low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
- close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
- dates = [datetime(year=2013, month=10, day=10),
- datetime(year=2013, month=11, day=10),
- datetime(year=2013, month=12, day=10),
- datetime(year=2014, month=1, day=10),
- datetime(year=2014, month=2, day=10)]
-
- # Create ohlc
- fig = FF.create_candlestick(open_data, high_data,
- low_data, close_data, dates=dates)
-
- py.iplot(fig, filename='finance/simple-candlestick', validate=False)
- ```
- """
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
- if dates is not None:
- FigureFactory._validate_equal_length(open, high, low, close, dates)
- else:
- FigureFactory._validate_equal_length(open, high, low, close)
- FigureFactory._validate_ohlc(open, high, low, close, direction,
- **kwargs)
-
- if direction is 'increasing':
- candle_incr_data = FigureFactory._make_increasing_candle(
- open, high, low, close, dates, **kwargs)
- data = candle_incr_data
- elif direction is 'decreasing':
- candle_decr_data = FigureFactory._make_decreasing_candle(
- open, high, low, close, dates, **kwargs)
- data = candle_decr_data
- else:
- candle_incr_data = FigureFactory._make_increasing_candle(
- open, high, low, close, dates, **kwargs)
- candle_decr_data = FigureFactory._make_decreasing_candle(
- open, high, low, close, dates, **kwargs)
- data = candle_incr_data + candle_decr_data
-
- layout = graph_objs.Layout()
- return graph_objs.Figure(data=data, layout=layout)
-
- @staticmethod
- def create_distplot(hist_data, group_labels,
- bin_size=1., curve_type='kde',
- colors=[], rug_text=[], histnorm=DEFAULT_HISTNORM,
- show_hist=True, show_curve=True,
- show_rug=True):
- """
- BETA function that creates a distplot similar to seaborn.distplot
-
- The distplot can be composed of all or any combination of the following
- 3 components: (1) histogram, (2) curve: (a) kernel density estimation
- or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
- (from multiple datasets) can be created in the same plot.
-
- :param (list[list]) hist_data: Use list of lists to plot multiple data
- sets on the same plot.
- :param (list[str]) group_labels: Names for each data set.
- :param (list[float]|float) bin_size: Size of histogram bins.
- Default = 1.
- :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
- :param (str) histnorm: 'probability density' or 'probability'
- Default = 'probability density'
- :param (bool) show_hist: Add histogram to distplot? Default = True
- :param (bool) show_curve: Add curve to distplot? Default = True
- :param (bool) show_rug: Add rug to distplot? Default = True
- :param (list[str]) colors: Colors for traces.
- :param (list[list]) rug_text: Hovertext values for rug_plot,
- :return (dict): Representation of a distplot figure.
-
- Example 1: Simple distplot of 1 data set
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
- 3.5, 4.1, 4.4, 4.5, 4.5,
- 5.0, 5.0, 5.2, 5.5, 5.5,
- 5.5, 5.5, 5.5, 6.1, 7.0]]
-
- group_labels = ['distplot example']
-
- fig = FF.create_distplot(hist_data, group_labels)
-
- url = py.plot(fig, filename='Simple distplot', validate=False)
- ```
-
- Example 2: Two data sets and added rug text
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- # Add histogram data
- hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
- -0.9, -0.07, 1.95, 0.9, -0.2,
- -0.5, 0.3, 0.4, -0.37, 0.6]
- hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
- 1.0, 0.8, 1.7, 0.5, 0.8,
- -0.3, 1.2, 0.56, 0.3, 2.2]
-
- # Group data together
- hist_data = [hist1_x, hist2_x]
-
- group_labels = ['2012', '2013']
-
- # Add text
- rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
- 'f1', 'g1', 'h1', 'i1', 'j1',
- 'k1', 'l1', 'm1', 'n1', 'o1']
-
- rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
- 'f2', 'g2', 'h2', 'i2', 'j2',
- 'k2', 'l2', 'm2', 'n2', 'o2']
-
- # Group text together
- rug_text_all = [rug_text_1, rug_text_2]
-
- # Create distplot
- fig = FF.create_distplot(
- hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)
-
- # Add title
- fig['layout'].update(title='Dist Plot')
-
- # Plot!
- url = py.plot(fig, filename='Distplot with rug text', validate=False)
- ```
-
- Example 3: Plot with normal curve and hide rug plot
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- import numpy as np
-
- x1 = np.random.randn(190)
- x2 = np.random.randn(200)+1
- x3 = np.random.randn(200)-1
- x4 = np.random.randn(210)+2
-
- hist_data = [x1, x2, x3, x4]
- group_labels = ['2012', '2013', '2014', '2015']
-
- fig = FF.create_distplot(
- hist_data, group_labels, curve_type='normal',
- show_rug=False, bin_size=.4)
-
- url = py.plot(fig, filename='hist and normal curve', validate=False)
-
- Example 4: Distplot with Pandas
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
- import numpy as np
- import pandas as pd
-
- df = pd.DataFrame({'2012': np.random.randn(200),
- '2013': np.random.randn(200)+1})
- py.iplot(FF.create_distplot([df[c] for c in df.columns], df.columns),
- filename='examples/distplot with pandas',
- validate=False)
- ```
- """
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
- FigureFactory._validate_distplot(hist_data, curve_type)
- FigureFactory._validate_equal_length(hist_data, group_labels)
-
- if isinstance(bin_size, (float, int)):
- bin_size = [bin_size]*len(hist_data)
-
- hist = _Distplot(
- hist_data, histnorm, group_labels, bin_size,
- curve_type, colors, rug_text,
- show_hist, show_curve).make_hist()
-
- if curve_type == 'normal':
- curve = _Distplot(
- hist_data, histnorm, group_labels, bin_size,
- curve_type, colors, rug_text,
- show_hist, show_curve).make_normal()
- else:
- curve = _Distplot(
- hist_data, histnorm, group_labels, bin_size,
- curve_type, colors, rug_text,
- show_hist, show_curve).make_kde()
-
- rug = _Distplot(
- hist_data, histnorm, group_labels, bin_size,
- curve_type, colors, rug_text,
- show_hist, show_curve).make_rug()
-
- data = []
- if show_hist:
- data.append(hist)
- if show_curve:
- data.append(curve)
- if show_rug:
- data.append(rug)
- layout = graph_objs.Layout(
- barmode='overlay',
- hovermode='closest',
- legend=dict(traceorder='reversed'),
- xaxis1=dict(domain=[0.0, 1.0],
- anchor='y2',
- zeroline=False),
- yaxis1=dict(domain=[0.35, 1],
- anchor='free',
- position=0.0),
- yaxis2=dict(domain=[0, 0.25],
- anchor='x1',
- dtick=1,
- showticklabels=False))
- else:
- layout = graph_objs.Layout(
- barmode='overlay',
- hovermode='closest',
- legend=dict(traceorder='reversed'),
- xaxis1=dict(domain=[0.0, 1.0],
- anchor='y2',
- zeroline=False),
- yaxis1=dict(domain=[0., 1],
- anchor='free',
- position=0.0))
-
- data = sum(data, [])
- return graph_objs.Figure(data=data, layout=layout)
-
- @staticmethod
- def create_dendrogram(X, orientation="bottom", labels=None,
- colorscale=None, distfun=None,
- linkagefun=lambda x: sch.linkage(x, 'complete')):
- """
- BETA function that returns a dendrogram Plotly figure object.
-
- :param (ndarray) X: Matrix of observations as array of arrays
- :param (str) orientation: 'top', 'right', 'bottom', or 'left'
- :param (list) labels: List of axis category labels(observation labels)
- :param (list) colorscale: Optional colorscale for dendrogram tree
- :param (function) distfun: Function to compute the pairwise distance from the observations
- :param (function) linkagefun: Function to compute the linkage matrix from the pairwise distances
-
- clusters
-
- Example 1: Simple bottom oriented dendrogram
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
-
- X = np.random.rand(10,10)
- dendro = FF.create_dendrogram(X)
- plot_url = py.plot(dendro, filename='simple-dendrogram')
-
- ```
-
- Example 2: Dendrogram to put on the left of the heatmap
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
-
- X = np.random.rand(5,5)
- names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
- dendro = FF.create_dendrogram(X, orientation='right', labels=names)
- dendro['layout'].update({'width':700, 'height':500})
-
- py.iplot(dendro, filename='vertical-dendrogram')
- ```
-
- Example 3: Dendrogram with Pandas
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import numpy as np
- import pandas as pd
-
- Index= ['A','B','C','D','E','F','G','H','I','J']
- df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
- fig = FF.create_dendrogram(df, labels=Index)
- url = py.plot(fig, filename='pandas-dendrogram')
- ```
- """
- dependencies = (_scipy_imported and _scipy__spatial_imported and
- _scipy__cluster__hierarchy_imported)
-
- if dependencies is False:
- raise ImportError("FigureFactory.create_dendrogram requires scipy, \
- scipy.spatial and scipy.hierarchy")
-
- s = X.shape
- if len(s) != 2:
- exceptions.PlotlyError("X should be 2-dimensional array.")
-
- if distfun is None:
- distfun = scs.distance.pdist
-
- dendrogram = _Dendrogram(X, orientation, labels, colorscale,
- distfun=distfun, linkagefun=linkagefun)
-
- return {'layout': dendrogram.layout,
- 'data': dendrogram.data}
-
- @staticmethod
- def create_annotated_heatmap(z, x=None, y=None, annotation_text=None,
- colorscale='RdBu', font_colors=None,
- showscale=False, reversescale=False,
- **kwargs):
- """
- BETA function that creates annotated heatmaps
-
- This function adds annotations to each cell of the heatmap.
-
- :param (list[list]|ndarray) z: z matrix to create heatmap.
- :param (list) x: x axis labels.
- :param (list) y: y axis labels.
- :param (list[list]|ndarray) annotation_text: Text strings for
- annotations. Should have the same dimensions as the z matrix. If no
- text is added, the values of the z matrix are annotated. Default =
- z matrix values.
- :param (list|str) colorscale: heatmap colorscale.
- :param (list) font_colors: List of two color strings: [min_text_color,
- max_text_color] where min_text_color is applied to annotations for
- heatmap values < (max_value - min_value)/2. If font_colors is not
- defined, the colors are defined logically as black or white
- depending on the heatmap's colorscale.
- :param (bool) showscale: Display colorscale. Default = False
- :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
- These kwargs describe other attributes about the annotated Heatmap
- trace such as the colorscale. For more information on valid kwargs
- call help(plotly.graph_objs.Heatmap)
-
- Example 1: Simple annotated heatmap with default configuration
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- z = [[0.300000, 0.00000, 0.65, 0.300000],
- [1, 0.100005, 0.45, 0.4300],
- [0.300000, 0.00000, 0.65, 0.300000],
- [1, 0.100005, 0.45, 0.00000]]
-
- figure = FF.create_annotated_heatmap(z)
- py.iplot(figure)
- ```
- """
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
-
- # Avoiding mutables in the call signature
- font_colors = font_colors if font_colors is not None else []
- FigureFactory._validate_annotated_heatmap(z, x, y, annotation_text)
- annotations = _AnnotatedHeatmap(z, x, y, annotation_text,
- colorscale, font_colors, reversescale,
- **kwargs).make_annotations()
-
- if x or y:
- trace = dict(type='heatmap', z=z, x=x, y=y, colorscale=colorscale,
- showscale=showscale, **kwargs)
- layout = dict(annotations=annotations,
- xaxis=dict(ticks='', dtick=1, side='top',
- gridcolor='rgb(0, 0, 0)'),
- yaxis=dict(ticks='', dtick=1, ticksuffix=' '))
- else:
- trace = dict(type='heatmap', z=z, colorscale=colorscale,
- showscale=showscale, **kwargs)
- layout = dict(annotations=annotations,
- xaxis=dict(ticks='', side='top',
- gridcolor='rgb(0, 0, 0)',
- showticklabels=False),
- yaxis=dict(ticks='', ticksuffix=' ',
- showticklabels=False))
-
- data = [trace]
-
- return graph_objs.Figure(data=data, layout=layout)
-
- @staticmethod
- def create_table(table_text, colorscale=None, font_colors=None,
- index=False, index_title='', annotation_offset=.45,
- height_constant=30, hoverinfo='none', **kwargs):
- """
- BETA function that creates data tables
-
- :param (pandas.Dataframe | list[list]) text: data for table.
- :param (str|list[list]) colorscale: Colorscale for table where the
- color at value 0 is the header color, .5 is the first table color
- and 1 is the second table color. (Set .5 and 1 to avoid the striped
- table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'],
- [1, '#ffffff']]
- :param (list) font_colors: Color for fonts in table. Can be a single
- color, three colors, or a color for each row in the table.
- Default=['#000000'] (black text for the entire table)
- :param (int) height_constant: Constant multiplied by # of rows to
- create table height. Default=30.
- :param (bool) index: Create (header-colored) index column index from
- Pandas dataframe or list[0] for each list in text. Default=False.
- :param (string) index_title: Title for index column. Default=''.
- :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
- These kwargs describe other attributes about the annotated Heatmap
- trace such as the colorscale. For more information on valid kwargs
- call help(plotly.graph_objs.Heatmap)
-
- Example 1: Simple Plotly Table
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- text = [['Country', 'Year', 'Population'],
- ['US', 2000, 282200000],
- ['Canada', 2000, 27790000],
- ['US', 2010, 309000000],
- ['Canada', 2010, 34000000]]
-
- table = FF.create_table(text)
- py.iplot(table)
- ```
-
- Example 2: Table with Custom Coloring
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- text = [['Country', 'Year', 'Population'],
- ['US', 2000, 282200000],
- ['Canada', 2000, 27790000],
- ['US', 2010, 309000000],
- ['Canada', 2010, 34000000]]
-
- table = FF.create_table(text,
- colorscale=[[0, '#000000'],
- [.5, '#80beff'],
- [1, '#cce5ff']],
- font_colors=['#ffffff', '#000000',
- '#000000'])
- py.iplot(table)
- ```
- Example 3: Simple Plotly Table with Pandas
- ```
- import plotly.plotly as py
- from plotly.tools import FigureFactory as FF
-
- import pandas as pd
-
- df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t')
- df_p = df[0:25]
-
- table_simple = FF.create_table(df_p)
- py.iplot(table_simple)
- ```
- """
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
-
- # Avoiding mutables in the call signature
- colorscale = \
- colorscale if colorscale is not None else [[0, '#00083e'],
- [.5, '#ededee'],
- [1, '#ffffff']]
- font_colors = font_colors if font_colors is not None else ['#ffffff',
- '#000000',
- '#000000']
-
- FigureFactory._validate_table(table_text, font_colors)
- table_matrix = _Table(table_text, colorscale, font_colors, index,
- index_title, annotation_offset,
- **kwargs).get_table_matrix()
- annotations = _Table(table_text, colorscale, font_colors, index,
- index_title, annotation_offset,
- **kwargs).make_table_annotations()
-
- trace = dict(type='heatmap', z=table_matrix, opacity=.75,
- colorscale=colorscale, showscale=False,
- hoverinfo=hoverinfo, **kwargs)
-
- data = [trace]
- layout = dict(annotations=annotations,
- height=len(table_matrix)*height_constant + 50,
- margin=dict(t=0, b=0, r=0, l=0),
- yaxis=dict(autorange='reversed', zeroline=False,
- gridwidth=2, ticks='', dtick=1, tick0=.5,
- showticklabels=False),
- xaxis=dict(zeroline=False, gridwidth=2, ticks='',
- dtick=1, tick0=-0.5, showticklabels=False))
- return graph_objs.Figure(data=data, layout=layout)
-
-
-class _Quiver(FigureFactory):
- """
- Refer to FigureFactory.create_quiver() for docstring
- """
- def __init__(self, x, y, u, v,
- scale, arrow_scale, angle, **kwargs):
- try:
- x = FigureFactory._flatten(x)
- except exceptions.PlotlyError:
- pass
-
- try:
- y = FigureFactory._flatten(y)
- except exceptions.PlotlyError:
- pass
-
- try:
- u = FigureFactory._flatten(u)
- except exceptions.PlotlyError:
- pass
-
- try:
- v = FigureFactory._flatten(v)
- except exceptions.PlotlyError:
- pass
-
- self.x = x
- self.y = y
- self.u = u
- self.v = v
- self.scale = scale
- self.arrow_scale = arrow_scale
- self.angle = angle
- self.end_x = []
- self.end_y = []
- self.scale_uv()
- barb_x, barb_y = self.get_barbs()
- arrow_x, arrow_y = self.get_quiver_arrows()
-
- def scale_uv(self):
- """
- Scales u and v to avoid overlap of the arrows.
-
- u and v are added to x and y to get the
- endpoints of the arrows so a smaller scale value will
- result in less overlap of arrows.
- """
- self.u = [i * self.scale for i in self.u]
- self.v = [i * self.scale for i in self.v]
-
- def get_barbs(self):
- """
- Creates x and y startpoint and endpoint pairs
-
- After finding the endpoint of each barb this zips startpoint and
- endpoint pairs to create 2 lists: x_values for barbs and y values
- for barbs
-
- :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint
- x_value pairs separated by a None to create the barb of the arrow,
- and list of startpoint and endpoint y_value pairs separated by a
- None to create the barb of the arrow.
- """
- self.end_x = [i + j for i, j in zip(self.x, self.u)]
- self.end_y = [i + j for i, j in zip(self.y, self.v)]
- empty = [None] * len(self.x)
- barb_x = FigureFactory._flatten(zip(self.x, self.end_x, empty))
- barb_y = FigureFactory._flatten(zip(self.y, self.end_y, empty))
- return barb_x, barb_y
-
- def get_quiver_arrows(self):
- """
- Creates lists of x and y values to plot the arrows
-
- Gets length of each barb then calculates the length of each side of
- the arrow. Gets angle of barb and applies angle to each side of the
- arrowhead. Next uses arrow_scale to scale the length of arrowhead and
- creates x and y values for arrowhead point1 and point2. Finally x and y
- values for point1, endpoint and point2s for each arrowhead are
- separated by a None and zipped to create lists of x and y values for
- the arrows.
-
- :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2
- x_values separated by a None to create the arrowhead and list of
- point1, endpoint, point2 y_values separated by a None to create
- the barb of the arrow.
- """
- dif_x = [i - j for i, j in zip(self.end_x, self.x)]
- dif_y = [i - j for i, j in zip(self.end_y, self.y)]
-
- # Get barb lengths(default arrow length = 30% barb length)
- barb_len = [None] * len(self.x)
- for index in range(len(barb_len)):
- barb_len[index] = math.hypot(dif_x[index], dif_y[index])
-
- # Make arrow lengths
- arrow_len = [None] * len(self.x)
- arrow_len = [i * self.arrow_scale for i in barb_len]
-
- # Get barb angles
- barb_ang = [None] * len(self.x)
- for index in range(len(barb_ang)):
- barb_ang[index] = math.atan2(dif_y[index], dif_x[index])
-
- # Set angles to create arrow
- ang1 = [i + self.angle for i in barb_ang]
- ang2 = [i - self.angle for i in barb_ang]
-
- cos_ang1 = [None] * len(ang1)
- for index in range(len(ang1)):
- cos_ang1[index] = math.cos(ang1[index])
- seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)]
-
- sin_ang1 = [None] * len(ang1)
- for index in range(len(ang1)):
- sin_ang1[index] = math.sin(ang1[index])
- seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)]
-
- cos_ang2 = [None] * len(ang2)
- for index in range(len(ang2)):
- cos_ang2[index] = math.cos(ang2[index])
- seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)]
-
- sin_ang2 = [None] * len(ang2)
- for index in range(len(ang2)):
- sin_ang2[index] = math.sin(ang2[index])
- seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)]
-
- # Set coordinates to create arrow
- for index in range(len(self.end_x)):
- point1_x = [i - j for i, j in zip(self.end_x, seg1_x)]
- point1_y = [i - j for i, j in zip(self.end_y, seg1_y)]
- point2_x = [i - j for i, j in zip(self.end_x, seg2_x)]
- point2_y = [i - j for i, j in zip(self.end_y, seg2_y)]
-
- # Combine lists to create arrow
- empty = [None] * len(self.end_x)
- arrow_x = FigureFactory._flatten(zip(point1_x, self.end_x,
- point2_x, empty))
- arrow_y = FigureFactory._flatten(zip(point1_y, self.end_y,
- point2_y, empty))
- return arrow_x, arrow_y
-
-
-class _Streamline(FigureFactory):
- """
- Refer to FigureFactory.create_streamline() for docstring
- """
- def __init__(self, x, y, u, v,
- density, angle,
- arrow_scale, **kwargs):
- self.x = np.array(x)
- self.y = np.array(y)
- self.u = np.array(u)
- self.v = np.array(v)
- self.angle = angle
- self.arrow_scale = arrow_scale
- self.density = int(30 * density) # Scale similarly to other functions
- self.delta_x = self.x[1] - self.x[0]
- self.delta_y = self.y[1] - self.y[0]
- self.val_x = self.x
- self.val_y = self.y
-
- # Set up spacing
- self.blank = np.zeros((self.density, self.density))
- self.spacing_x = len(self.x) / float(self.density - 1)
- self.spacing_y = len(self.y) / float(self.density - 1)
- self.trajectories = []
-
- # Rescale speed onto axes-coordinates
- self.u = self.u / (self.x[-1] - self.x[0])
- self.v = self.v / (self.y[-1] - self.y[0])
- self.speed = np.sqrt(self.u ** 2 + self.v ** 2)
-
- # Rescale u and v for integrations.
- self.u *= len(self.x)
- self.v *= len(self.y)
- self.st_x = []
- self.st_y = []
- self.get_streamlines()
- streamline_x, streamline_y = self.sum_streamlines()
- arrows_x, arrows_y = self.get_streamline_arrows()
-
- def blank_pos(self, xi, yi):
- """
- Set up positions for trajectories to be used with rk4 function.
- """
- return (int((xi / self.spacing_x) + 0.5),
- int((yi / self.spacing_y) + 0.5))
-
- def value_at(self, a, xi, yi):
- """
- Set up for RK4 function, based on Bokeh's streamline code
- """
- if isinstance(xi, np.ndarray):
- self.x = xi.astype(np.int)
- self.y = yi.astype(np.int)
- else:
- self.val_x = np.int(xi)
- self.val_y = np.int(yi)
- a00 = a[self.val_y, self.val_x]
- a01 = a[self.val_y, self.val_x + 1]
- a10 = a[self.val_y + 1, self.val_x]
- a11 = a[self.val_y + 1, self.val_x + 1]
- xt = xi - self.val_x
- yt = yi - self.val_y
- a0 = a00 * (1 - xt) + a01 * xt
- a1 = a10 * (1 - xt) + a11 * xt
- return a0 * (1 - yt) + a1 * yt
-
- def rk4_integrate(self, x0, y0):
- """
- RK4 forward and back trajectories from the initial conditions.
-
- Adapted from Bokeh's streamline -uses Runge-Kutta method to fill
- x and y trajectories then checks length of traj (s in units of axes)
- """
- def f(xi, yi):
- dt_ds = 1. / self.value_at(self.speed, xi, yi)
- ui = self.value_at(self.u, xi, yi)
- vi = self.value_at(self.v, xi, yi)
- return ui * dt_ds, vi * dt_ds
-
- def g(xi, yi):
- dt_ds = 1. / self.value_at(self.speed, xi, yi)
- ui = self.value_at(self.u, xi, yi)
- vi = self.value_at(self.v, xi, yi)
- return -ui * dt_ds, -vi * dt_ds
-
- check = lambda xi, yi: (0 <= xi < len(self.x) - 1 and
- 0 <= yi < len(self.y) - 1)
- xb_changes = []
- yb_changes = []
-
- def rk4(x0, y0, f):
- ds = 0.01
- stotal = 0
- xi = x0
- yi = y0
- xb, yb = self.blank_pos(xi, yi)
- xf_traj = []
- yf_traj = []
- while check(xi, yi):
- xf_traj.append(xi)
- yf_traj.append(yi)
- try:
- k1x, k1y = f(xi, yi)
- k2x, k2y = f(xi + .5 * ds * k1x, yi + .5 * ds * k1y)
- k3x, k3y = f(xi + .5 * ds * k2x, yi + .5 * ds * k2y)
- k4x, k4y = f(xi + ds * k3x, yi + ds * k3y)
- except IndexError:
- break
- xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.
- yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.
- if not check(xi, yi):
- break
- stotal += ds
- new_xb, new_yb = self.blank_pos(xi, yi)
- if new_xb != xb or new_yb != yb:
- if self.blank[new_yb, new_xb] == 0:
- self.blank[new_yb, new_xb] = 1
- xb_changes.append(new_xb)
- yb_changes.append(new_yb)
- xb = new_xb
- yb = new_yb
- else:
- break
- if stotal > 2:
- break
- return stotal, xf_traj, yf_traj
-
- sf, xf_traj, yf_traj = rk4(x0, y0, f)
- sb, xb_traj, yb_traj = rk4(x0, y0, g)
- stotal = sf + sb
- x_traj = xb_traj[::-1] + xf_traj[1:]
- y_traj = yb_traj[::-1] + yf_traj[1:]
-
- if len(x_traj) < 1:
- return None
- if stotal > .2:
- initxb, inityb = self.blank_pos(x0, y0)
- self.blank[inityb, initxb] = 1
- return x_traj, y_traj
- else:
- for xb, yb in zip(xb_changes, yb_changes):
- self.blank[yb, xb] = 0
- return None
-
- def traj(self, xb, yb):
- """
- Integrate trajectories
-
- :param (int) xb: results of passing xi through self.blank_pos
- :param (int) xy: results of passing yi through self.blank_pos
-
- Calculate each trajectory based on rk4 integrate method.
- """
-
- if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density:
- return
- if self.blank[yb, xb] == 0:
- t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y)
- if t is not None:
- self.trajectories.append(t)
-
- def get_streamlines(self):
- """
- Get streamlines by building trajectory set.
- """
- for indent in range(self.density // 2):
- for xi in range(self.density - 2 * indent):
- self.traj(xi + indent, indent)
- self.traj(xi + indent, self.density - 1 - indent)
- self.traj(indent, xi + indent)
- self.traj(self.density - 1 - indent, xi + indent)
-
- self.st_x = [np.array(t[0]) * self.delta_x + self.x[0] for t in
- self.trajectories]
- self.st_y = [np.array(t[1]) * self.delta_y + self.y[0] for t in
- self.trajectories]
-
- for index in range(len(self.st_x)):
- self.st_x[index] = self.st_x[index].tolist()
- self.st_x[index].append(np.nan)
-
- for index in range(len(self.st_y)):
- self.st_y[index] = self.st_y[index].tolist()
- self.st_y[index].append(np.nan)
-
- def get_streamline_arrows(self):
- """
- Makes an arrow for each streamline.
-
- Gets angle of streamline at 1/3 mark and creates arrow coordinates
- based off of user defined angle and arrow_scale.
-
- :param (array) st_x: x-values for all streamlines
- :param (array) st_y: y-values for all streamlines
- :param (angle in radians) angle: angle of arrowhead. Default = pi/9
- :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
- Default = .09
- :rtype (list, list) arrows_x: x-values to create arrowhead and
- arrows_y: y-values to create arrowhead
- """
- arrow_end_x = np.empty((len(self.st_x)))
- arrow_end_y = np.empty((len(self.st_y)))
- arrow_start_x = np.empty((len(self.st_x)))
- arrow_start_y = np.empty((len(self.st_y)))
- for index in range(len(self.st_x)):
- arrow_end_x[index] = (self.st_x[index]
- [int(len(self.st_x[index]) / 3)])
- arrow_start_x[index] = (self.st_x[index]
- [(int(len(self.st_x[index]) / 3)) - 1])
- arrow_end_y[index] = (self.st_y[index]
- [int(len(self.st_y[index]) / 3)])
- arrow_start_y[index] = (self.st_y[index]
- [(int(len(self.st_y[index]) / 3)) - 1])
-
- dif_x = arrow_end_x - arrow_start_x
- dif_y = arrow_end_y - arrow_start_y
-
- streamline_ang = np.arctan(dif_y / dif_x)
-
- ang1 = streamline_ang + (self.angle)
- ang2 = streamline_ang - (self.angle)
-
- seg1_x = np.cos(ang1) * self.arrow_scale
- seg1_y = np.sin(ang1) * self.arrow_scale
- seg2_x = np.cos(ang2) * self.arrow_scale
- seg2_y = np.sin(ang2) * self.arrow_scale
-
- point1_x = np.empty((len(dif_x)))
- point1_y = np.empty((len(dif_y)))
- point2_x = np.empty((len(dif_x)))
- point2_y = np.empty((len(dif_y)))
-
- for index in range(len(dif_x)):
- if dif_x[index] >= 0:
- point1_x[index] = arrow_end_x[index] - seg1_x[index]
- point1_y[index] = arrow_end_y[index] - seg1_y[index]
- point2_x[index] = arrow_end_x[index] - seg2_x[index]
- point2_y[index] = arrow_end_y[index] - seg2_y[index]
- else:
- point1_x[index] = arrow_end_x[index] + seg1_x[index]
- point1_y[index] = arrow_end_y[index] + seg1_y[index]
- point2_x[index] = arrow_end_x[index] + seg2_x[index]
- point2_y[index] = arrow_end_y[index] + seg2_y[index]
-
- space = np.empty((len(point1_x)))
- space[:] = np.nan
-
- # Combine arrays into matrix
- arrows_x = np.matrix([point1_x, arrow_end_x, point2_x, space])
- arrows_x = np.array(arrows_x)
- arrows_x = arrows_x.flatten('F')
- arrows_x = arrows_x.tolist()
-
- # Combine arrays into matrix
- arrows_y = np.matrix([point1_y, arrow_end_y, point2_y, space])
- arrows_y = np.array(arrows_y)
- arrows_y = arrows_y.flatten('F')
- arrows_y = arrows_y.tolist()
-
- return arrows_x, arrows_y
-
- def sum_streamlines(self):
- """
- Makes all streamlines readable as a single trace.
-
- :rtype (list, list): streamline_x: all x values for each streamline
- combined into single list and streamline_y: all y values for each
- streamline combined into single list
- """
- streamline_x = sum(self.st_x, [])
- streamline_y = sum(self.st_y, [])
- return streamline_x, streamline_y
-
-
-class _OHLC(FigureFactory):
- """
- Refer to FigureFactory.create_ohlc_increase() for docstring.
- """
- def __init__(self, open, high, low, close, dates, **kwargs):
- self.open = open
- self.high = high
- self.low = low
- self.close = close
- self.empty = [None] * len(open)
- self.dates = dates
-
- self.all_x = []
- self.all_y = []
- self.increase_x = []
- self.increase_y = []
- self.decrease_x = []
- self.decrease_y = []
- self.get_all_xy()
- self.separate_increase_decrease()
-
- def get_all_xy(self):
- """
- Zip data to create OHLC shape
-
- OHLC shape: low to high vertical bar with
- horizontal branches for open and close values.
- If dates were added, the smallest date difference is calculated and
- multiplied by .2 to get the length of the open and close branches.
- If no date data was provided, the x-axis is a list of integers and the
- length of the open and close branches is .2.
- """
- self.all_y = list(zip(self.open, self.open, self.high,
- self.low, self.close, self.close, self.empty))
- if self.dates is not None:
- date_dif = []
- for i in range(len(self.dates) - 1):
- date_dif.append(self.dates[i + 1] - self.dates[i])
- date_dif_min = (min(date_dif)) / 5
- self.all_x = [[x - date_dif_min, x, x, x, x, x +
- date_dif_min, None] for x in self.dates]
- else:
- self.all_x = [[x - .2, x, x, x, x, x + .2, None]
- for x in range(len(self.open))]
-
- def separate_increase_decrease(self):
- """
- Separate data into two groups: increase and decrease
-
- (1) Increase, where close > open and
- (2) Decrease, where close <= open
- """
- for index in range(len(self.open)):
- if self.close[index] is None:
- pass
- elif self.close[index] > self.open[index]:
- self.increase_x.append(self.all_x[index])
- self.increase_y.append(self.all_y[index])
- else:
- self.decrease_x.append(self.all_x[index])
- self.decrease_y.append(self.all_y[index])
-
- def get_increase(self):
- """
- Flatten increase data and get increase text
-
- :rtype (list, list, list): flat_increase_x: x-values for the increasing
- trace, flat_increase_y: y=values for the increasing trace and
- text_increase: hovertext for the increasing trace
- """
- flat_increase_x = FigureFactory._flatten(self.increase_x)
- flat_increase_y = FigureFactory._flatten(self.increase_y)
- text_increase = (("Open", "Open", "High",
- "Low", "Close", "Close", '')
- * (len(self.increase_x)))
-
- return flat_increase_x, flat_increase_y, text_increase
-
- def get_decrease(self):
- """
- Flatten decrease data and get decrease text
-
- :rtype (list, list, list): flat_decrease_x: x-values for the decreasing
- trace, flat_decrease_y: y=values for the decreasing trace and
- text_decrease: hovertext for the decreasing trace
- """
- flat_decrease_x = FigureFactory._flatten(self.decrease_x)
- flat_decrease_y = FigureFactory._flatten(self.decrease_y)
- text_decrease = (("Open", "Open", "High",
- "Low", "Close", "Close", '')
- * (len(self.decrease_x)))
-
- return flat_decrease_x, flat_decrease_y, text_decrease
-
-
-class _Candlestick(FigureFactory):
- """
- Refer to FigureFactory.create_candlestick() for docstring.
- """
- def __init__(self, open, high, low, close, dates, **kwargs):
- self.open = open
- self.high = high
- self.low = low
- self.close = close
- if dates is not None:
- self.x = dates
- else:
- self.x = [x for x in range(len(self.open))]
- self.get_candle_increase()
-
- def get_candle_increase(self):
- """
- Separate increasing data from decreasing data.
-
- The data is increasing when close value > open value
- and decreasing when the close value <= open value.
- """
- increase_y = []
- increase_x = []
- for index in range(len(self.open)):
- if self.close[index] > self.open[index]:
- increase_y.append(self.low[index])
- increase_y.append(self.open[index])
- increase_y.append(self.close[index])
- increase_y.append(self.close[index])
- increase_y.append(self.close[index])
- increase_y.append(self.high[index])
- increase_x.append(self.x[index])
-
- increase_x = [[x, x, x, x, x, x] for x in increase_x]
- increase_x = FigureFactory._flatten(increase_x)
-
- return increase_x, increase_y
-
- def get_candle_decrease(self):
- """
- Separate increasing data from decreasing data.
-
- The data is increasing when close value > open value
- and decreasing when the close value <= open value.
- """
- decrease_y = []
- decrease_x = []
- for index in range(len(self.open)):
- if self.close[index] <= self.open[index]:
- decrease_y.append(self.low[index])
- decrease_y.append(self.open[index])
- decrease_y.append(self.close[index])
- decrease_y.append(self.close[index])
- decrease_y.append(self.close[index])
- decrease_y.append(self.high[index])
- decrease_x.append(self.x[index])
-
- decrease_x = [[x, x, x, x, x, x] for x in decrease_x]
- decrease_x = FigureFactory._flatten(decrease_x)
-
- return decrease_x, decrease_y
-
-
-class _Distplot(FigureFactory):
- """
- Refer to TraceFactory.create_distplot() for docstring
- """
- def __init__(self, hist_data, histnorm, group_labels,
- bin_size, curve_type, colors,
- rug_text, show_hist, show_curve):
- self.hist_data = hist_data
- self.histnorm = histnorm
- self.group_labels = group_labels
- self.bin_size = bin_size
- self.show_hist = show_hist
- self.show_curve = show_curve
- self.trace_number = len(hist_data)
- if rug_text:
- self.rug_text = rug_text
- else:
- self.rug_text = [None] * self.trace_number
-
- self.start = []
- self.end = []
- if colors:
- self.colors = colors
- else:
- self.colors = [
- "rgb(31, 119, 180)", "rgb(255, 127, 14)",
- "rgb(44, 160, 44)", "rgb(214, 39, 40)",
- "rgb(148, 103, 189)", "rgb(140, 86, 75)",
- "rgb(227, 119, 194)", "rgb(127, 127, 127)",
- "rgb(188, 189, 34)", "rgb(23, 190, 207)"]
- self.curve_x = [None] * self.trace_number
- self.curve_y = [None] * self.trace_number
-
- for trace in self.hist_data:
- self.start.append(min(trace) * 1.)
- self.end.append(max(trace) * 1.)
-
- def make_hist(self):
- """
- Makes the histogram(s) for FigureFactory.create_distplot().
-
- :rtype (list) hist: list of histogram representations
- """
- hist = [None] * self.trace_number
-
- for index in range(self.trace_number):
- hist[index] = dict(type='histogram',
- x=self.hist_data[index],
- xaxis='x1',
- yaxis='y1',
- histnorm=self.histnorm,
- name=self.group_labels[index],
- legendgroup=self.group_labels[index],
- marker=dict(color=self.colors[index]),
- autobinx=False,
- xbins=dict(start=self.start[index],
- end=self.end[index],
- size=self.bin_size[index]),
- opacity=.7)
- return hist
-
- def make_kde(self):
- """
- Makes the kernel density estimation(s) for create_distplot().
-
- This is called when curve_type = 'kde' in create_distplot().
-
- :rtype (list) curve: list of kde representations
- """
- curve = [None] * self.trace_number
- for index in range(self.trace_number):
- self.curve_x[index] = [self.start[index] +
- x * (self.end[index] - self.start[index])
- / 500 for x in range(500)]
- self.curve_y[index] = (scipy.stats.gaussian_kde
- (self.hist_data[index])
- (self.curve_x[index]))
-
- if self.histnorm == ALTERNATIVE_HISTNORM:
- self.curve_y[index] *= self.bin_size[index]
-
- for index in range(self.trace_number):
- curve[index] = dict(type='scatter',
- x=self.curve_x[index],
- y=self.curve_y[index],
- xaxis='x1',
- yaxis='y1',
- mode='lines',
- name=self.group_labels[index],
- legendgroup=self.group_labels[index],
- showlegend=False if self.show_hist else True,
- marker=dict(color=self.colors[index]))
- return curve
-
- def make_normal(self):
- """
- Makes the normal curve(s) for create_distplot().
-
- This is called when curve_type = 'normal' in create_distplot().
-
- :rtype (list) curve: list of normal curve representations
- """
- curve = [None] * self.trace_number
- mean = [None] * self.trace_number
- sd = [None] * self.trace_number
-
- for index in range(self.trace_number):
- mean[index], sd[index] = (scipy.stats.norm.fit
- (self.hist_data[index]))
- self.curve_x[index] = [self.start[index] +
- x * (self.end[index] - self.start[index])
- / 500 for x in range(500)]
- self.curve_y[index] = scipy.stats.norm.pdf(
- self.curve_x[index], loc=mean[index], scale=sd[index])
-
- if self.histnorm == ALTERNATIVE_HISTNORM:
- self.curve_y[index] *= self.bin_size[index]
-
- for index in range(self.trace_number):
- curve[index] = dict(type='scatter',
- x=self.curve_x[index],
- y=self.curve_y[index],
- xaxis='x1',
- yaxis='y1',
- mode='lines',
- name=self.group_labels[index],
- legendgroup=self.group_labels[index],
- showlegend=False if self.show_hist else True,
- marker=dict(color=self.colors[index]))
- return curve
-
- def make_rug(self):
- """
- Makes the rug plot(s) for create_distplot().
-
- :rtype (list) rug: list of rug plot representations
- """
- rug = [None] * self.trace_number
- for index in range(self.trace_number):
-
- rug[index] = dict(type='scatter',
- x=self.hist_data[index],
- y=([self.group_labels[index]] *
- len(self.hist_data[index])),
- xaxis='x1',
- yaxis='y2',
- mode='markers',
- name=self.group_labels[index],
- legendgroup=self.group_labels[index],
- showlegend=(False if self.show_hist or
- self.show_curve else True),
- text=self.rug_text[index],
- marker=dict(color=self.colors[index],
- symbol='line-ns-open'))
- return rug
-
-
-class _Dendrogram(FigureFactory):
- """Refer to FigureFactory.create_dendrogram() for docstring."""
-
- def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
- width="100%", height="100%", xaxis='xaxis', yaxis='yaxis',
- distfun=None, linkagefun=lambda x: sch.linkage(x, 'complete')):
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
- self.orientation = orientation
- self.labels = labels
- self.xaxis = xaxis
- self.yaxis = yaxis
- self.data = []
- self.leaves = []
- self.sign = {self.xaxis: 1, self.yaxis: 1}
- self.layout = {self.xaxis: {}, self.yaxis: {}}
-
- if self.orientation in ['left', 'bottom']:
- self.sign[self.xaxis] = 1
- else:
- self.sign[self.xaxis] = -1
-
- if self.orientation in ['right', 'bottom']:
- self.sign[self.yaxis] = 1
- else:
- self.sign[self.yaxis] = -1
-
- if distfun is None:
- distfun = scs.distance.pdist
-
- (dd_traces, xvals, yvals,
- ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale, distfun, linkagefun)
-
- self.labels = ordered_labels
- self.leaves = leaves
- yvals_flat = yvals.flatten()
- xvals_flat = xvals.flatten()
-
- self.zero_vals = []
-
- for i in range(len(yvals_flat)):
- if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
- self.zero_vals.append(xvals_flat[i])
-
- self.zero_vals.sort()
-
- self.layout = self.set_figure_layout(width, height)
- self.data = graph_objs.Data(dd_traces)
-
- def get_color_dict(self, colorscale):
- """
- Returns colorscale used for dendrogram tree clusters.
-
- :param (list) colorscale: Colors to use for the plot in rgb format.
- :rtype (dict): A dict of default colors mapped to the user colorscale.
-
- """
-
- # These are the color codes returned for dendrograms
- # We're replacing them with nicer colors
- d = {'r': 'red',
- 'g': 'green',
- 'b': 'blue',
- 'c': 'cyan',
- 'm': 'magenta',
- 'y': 'yellow',
- 'k': 'black',
- 'w': 'white'}
- default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
-
- if colorscale is None:
- colorscale = [
- 'rgb(0,116,217)', # blue
- 'rgb(35,205,205)', # cyan
- 'rgb(61,153,112)', # green
- 'rgb(40,35,35)', # black
- 'rgb(133,20,75)', # magenta
- 'rgb(255,65,54)', # red
- 'rgb(255,255,255)', # white
- 'rgb(255,220,0)'] # yellow
-
- for i in range(len(default_colors.keys())):
- k = list(default_colors.keys())[i] # PY3 won't index keys
- if i < len(colorscale):
- default_colors[k] = colorscale[i]
-
- return default_colors
-
- def set_axis_layout(self, axis_key):
- """
- Sets and returns default axis object for dendrogram figure.
-
- :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
- :rtype (dict): An axis_key dictionary with set parameters.
-
- """
- axis_defaults = {
- 'type': 'linear',
- 'ticks': 'outside',
- 'mirror': 'allticks',
- 'rangemode': 'tozero',
- 'showticklabels': True,
- 'zeroline': False,
- 'showgrid': False,
- 'showline': True,
- }
-
- if len(self.labels) != 0:
- axis_key_labels = self.xaxis
- if self.orientation in ['left', 'right']:
- axis_key_labels = self.yaxis
- if axis_key_labels not in self.layout:
- self.layout[axis_key_labels] = {}
- self.layout[axis_key_labels]['tickvals'] = \
- [zv*self.sign[axis_key] for zv in self.zero_vals]
- self.layout[axis_key_labels]['ticktext'] = self.labels
- self.layout[axis_key_labels]['tickmode'] = 'array'
-
- self.layout[axis_key].update(axis_defaults)
-
- return self.layout[axis_key]
-
- def set_figure_layout(self, width, height):
- """
- Sets and returns default layout object for dendrogram figure.
-
- """
- self.layout.update({
- 'showlegend': False,
- 'autosize': False,
- 'hovermode': 'closest',
- 'width': width,
- 'height': height
- })
-
- self.set_axis_layout(self.xaxis)
- self.set_axis_layout(self.yaxis)
-
- return self.layout
-
- def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun):
- """
- Calculates all the elements needed for plotting a dendrogram.
-
- :param (ndarray) X: Matrix of observations as array of arrays
- :param (list) colorscale: Color scale for dendrogram tree clusters
- :param (function) distfun: Function to compute the pairwise distance from the observations
- :param (function) linkagefun: Function to compute the linkage matrix from the pairwise distances
- :rtype (tuple): Contains all the traces in the following order:
- (a) trace_list: List of Plotly trace objects for dendrogram tree
- (b) icoord: All X points of the dendrogram tree as array of arrays
- with length 4
- (c) dcoord: All Y points of the dendrogram tree as array of arrays
- with length 4
- (d) ordered_labels: leaf labels in the order they are going to
- appear on the plot
- (e) P['leaves']: left-to-right traversal of the leaves
-
- """
- # TODO: protected until #282
- from plotly.graph_objs import graph_objs
- d = distfun(X)
- Z = linkagefun(d)
- P = sch.dendrogram(Z, orientation=self.orientation,
- labels=self.labels, no_plot=True)
-
- icoord = scp.array(P['icoord'])
- dcoord = scp.array(P['dcoord'])
- ordered_labels = scp.array(P['ivl'])
- color_list = scp.array(P['color_list'])
- colors = self.get_color_dict(colorscale)
-
- trace_list = []
-
- for i in range(len(icoord)):
- # xs and ys are arrays of 4 points that make up the '∩' shapes
- # of the dendrogram tree
- if self.orientation in ['top', 'bottom']:
- xs = icoord[i]
- else:
- xs = dcoord[i]
-
- if self.orientation in ['top', 'bottom']:
- ys = dcoord[i]
- else:
- ys = icoord[i]
- color_key = color_list[i]
- trace = graph_objs.Scatter(
- x=np.multiply(self.sign[self.xaxis], xs),
- y=np.multiply(self.sign[self.yaxis], ys),
- mode='lines',
- marker=graph_objs.Marker(color=colors[color_key])
- )
-
- try:
- x_index = int(self.xaxis[-1])
- except ValueError:
- x_index = ''
-
- try:
- y_index = int(self.yaxis[-1])
- except ValueError:
- y_index = ''
-
- trace['xaxis'] = 'x' + x_index
- trace['yaxis'] = 'y' + y_index
-
- trace_list.append(trace)
-
- return trace_list, icoord, dcoord, ordered_labels, P['leaves']
-
-
-class _AnnotatedHeatmap(FigureFactory):
- """
- Refer to TraceFactory.create_annotated_heatmap() for docstring
- """
- def __init__(self, z, x, y, annotation_text, colorscale,
- font_colors, reversescale, **kwargs):
- from plotly.graph_objs import graph_objs
-
- self.z = z
- if x:
- self.x = x
- else:
- self.x = range(len(z[0]))
- if y:
- self.y = y
- else:
- self.y = range(len(z))
- if annotation_text is not None:
- self.annotation_text = annotation_text
- else:
- self.annotation_text = self.z
- self.colorscale = colorscale
- self.reversescale = reversescale
- self.font_colors = font_colors
-
- def get_text_color(self):
- """
- Get font color for annotations.
-
- The annotated heatmap can feature two text colors: min_text_color and
- max_text_color. The min_text_color is applied to annotations for
- heatmap values < (max_value - min_value)/2. The user can define these
- two colors. Otherwise the colors are defined logically as black or
- white depending on the heatmap's colorscale.
-
- :rtype (string, string) min_text_color, max_text_color: text
- color for annotations for heatmap values <
- (max_value - min_value)/2 and text color for annotations for
- heatmap values >= (max_value - min_value)/2
- """
- # Plotly colorscales ranging from a lighter shade to a darker shade
- colorscales = ['Greys', 'Greens', 'Blues',
- 'YIGnBu', 'YIOrRd', 'RdBu',
- 'Picnic', 'Jet', 'Hot', 'Blackbody',
- 'Earth', 'Electric', 'Viridis']
- # Plotly colorscales ranging from a darker shade to a lighter shade
- colorscales_reverse = ['Reds']
- if self.font_colors:
- min_text_color = self.font_colors[0]
- max_text_color = self.font_colors[-1]
- elif self.colorscale in colorscales and self.reversescale:
- min_text_color = '#000000'
- max_text_color = '#FFFFFF'
- elif self.colorscale in colorscales:
- min_text_color = '#FFFFFF'
- max_text_color = '#000000'
- elif self.colorscale in colorscales_reverse and self.reversescale:
- min_text_color = '#FFFFFF'
- max_text_color = '#000000'
- elif self.colorscale in colorscales_reverse:
- min_text_color = '#000000'
- max_text_color = '#FFFFFF'
- elif isinstance(self.colorscale, list):
- if 'rgb' in self.colorscale[0][1]:
- min_col = map(int,
- self.colorscale[0][1].strip('rgb()').split(','))
- max_col = map(int,
- self.colorscale[-1][1].strip('rgb()').split(','))
- elif '#' in self.colorscale[0][1]:
- min_col = FigureFactory._hex_to_rgb(self.colorscale[0][1])
- max_col = FigureFactory._hex_to_rgb(self.colorscale[-1][1])
- else:
- min_col = [255, 255, 255]
- max_col = [255, 255, 255]
-
- if (min_col[0]*0.299 + min_col[1]*0.587 + min_col[2]*0.114) > 186:
- min_text_color = '#000000'
- else:
- min_text_color = '#FFFFFF'
- if (max_col[0]*0.299 + max_col[1]*0.587 + max_col[2]*0.114) > 186:
- max_text_color = '#000000'
- else:
- max_text_color = '#FFFFFF'
- else:
- min_text_color = '#000000'
- max_text_color = '#000000'
- return min_text_color, max_text_color
-
- def get_z_mid(self):
- """
- Get the mid value of z matrix
-
- :rtype (float) z_avg: average val from z matrix
- """
- if _numpy_imported and isinstance(self.z, np.ndarray):
- z_min = np.amin(self.z)
- z_max = np.amax(self.z)
- else:
- z_min = min(min(self.z))
- z_max = max(max(self.z))
- z_mid = (z_max+z_min) / 2
- return z_mid
-
- def make_annotations(self):
- """
- Get annotations for each cell of the heatmap with graph_objs.Annotation
-
- :rtype (list[dict]) annotations: list of annotations for each cell of
- the heatmap
- """
- from plotly.graph_objs import graph_objs
- min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
- z_mid = _AnnotatedHeatmap.get_z_mid(self)
- annotations = []
- for n, row in enumerate(self.z):
- for m, val in enumerate(row):
- font_color = min_text_color if val < z_mid else max_text_color
- annotations.append(
- graph_objs.Annotation(
- text=str(self.annotation_text[n][m]),
- x=self.x[m],
- y=self.y[n],
- xref='x1',
- yref='y1',
- font=dict(color=font_color),
- showarrow=False))
- return annotations
-
-
-class _Table(FigureFactory):
- """
- Refer to TraceFactory.create_table() for docstring
- """
- def __init__(self, table_text, colorscale, font_colors, index,
- index_title, annotation_offset, **kwargs):
- from plotly.graph_objs import graph_objs
- if _pandas_imported and isinstance(table_text, pd.DataFrame):
- headers = table_text.columns.tolist()
- table_text_index = table_text.index.tolist()
- table_text = table_text.values.tolist()
- table_text.insert(0, headers)
- if index:
- table_text_index.insert(0, index_title)
- for i in range(len(table_text)):
- table_text[i].insert(0, table_text_index[i])
- self.table_text = table_text
- self.colorscale = colorscale
- self.font_colors = font_colors
- self.index = index
- self.annotation_offset = annotation_offset
- self.x = range(len(table_text[0]))
- self.y = range(len(table_text))
-
- def get_table_matrix(self):
- """
- Create z matrix to make heatmap with striped table coloring
-
- :rtype (list[list]) table_matrix: z matrix to make heatmap with striped
- table coloring.
- """
- header = [0] * len(self.table_text[0])
- odd_row = [.5] * len(self.table_text[0])
- even_row = [1] * len(self.table_text[0])
- table_matrix = [None] * len(self.table_text)
- table_matrix[0] = header
- for i in range(1, len(self.table_text), 2):
- table_matrix[i] = odd_row
- for i in range(2, len(self.table_text), 2):
- table_matrix[i] = even_row
- if self.index:
- for array in table_matrix:
- array[0] = 0
- return table_matrix
-
- def get_table_font_color(self):
- """
- Fill font-color array.
-
- Table text color can vary by row so this extends a single color or
- creates an array to set a header color and two alternating colors to
- create the striped table pattern.
-
- :rtype (list[list]) all_font_colors: list of font colors for each row
- in table.
- """
- if len(self.font_colors) == 1:
- all_font_colors = self.font_colors*len(self.table_text)
- elif len(self.font_colors) == 3:
- all_font_colors = list(range(len(self.table_text)))
- all_font_colors[0] = self.font_colors[0]
- for i in range(1, len(self.table_text), 2):
- all_font_colors[i] = self.font_colors[1]
- for i in range(2, len(self.table_text), 2):
- all_font_colors[i] = self.font_colors[2]
- elif len(self.font_colors) == len(self.table_text):
- all_font_colors = self.font_colors
- else:
- all_font_colors = ['#000000']*len(self.table_text)
- return all_font_colors
-
- def make_table_annotations(self):
- """
- Generate annotations to fill in table text
-
- :rtype (list) annotations: list of annotations for each cell of the
- table.
- """
- from plotly.graph_objs import graph_objs
- table_matrix = _Table.get_table_matrix(self)
- all_font_colors = _Table.get_table_font_color(self)
- annotations = []
- for n, row in enumerate(self.table_text):
- for m, val in enumerate(row):
- # Bold text in header and index
- format_text = ('' + str(val) + '' if n == 0 or
- self.index and m < 1 else str(val))
- # Match font color of index to font color of header
- font_color = (self.font_colors[0] if self.index and m == 0
- else all_font_colors[n])
- annotations.append(
- graph_objs.Annotation(
- text=format_text,
- x=self.x[m] - self.annotation_offset,
- y=self.y[n],
- xref='x1',
- yref='y1',
- align="left",
- xanchor="left",
- font=dict(color=font_color),
- showarrow=False))
- return annotations
+ def create_violin(*args, **kwargs):
+ FigureFactory._deprecated('create_violin')
+ from plotly.figure_factory import create_violin
+ return create_violin(*args, **kwargs)
diff --git a/plotly/utils.py b/plotly/utils.py
index 782779c457c..37fdd3a26bb 100644
--- a/plotly/utils.py
+++ b/plotly/utils.py
@@ -7,7 +7,6 @@
"""
from __future__ import absolute_import
-import json
import os.path
import re
import sys
@@ -15,27 +14,16 @@
import decimal
import pytz
+from requests.compat import json as _json
+from plotly.optional_imports import get_module
from . exceptions import PlotlyError
-try:
- import numpy
- _numpy_imported = True
-except ImportError:
- _numpy_imported = False
-
-try:
- import pandas
- _pandas_imported = True
-except ImportError:
- _pandas_imported = False
-
-try:
- import sage.all
- _sage_imported = True
-except ImportError:
- _sage_imported = False
+# Optional imports, may be None for users that only use our core functionality.
+numpy = get_module('numpy')
+pandas = get_module('pandas')
+sage_all = get_module('sage.all')
### incase people are using threading, we lock file reads
@@ -51,7 +39,7 @@ def load_json_dict(filename, *args):
lock.acquire()
with open(filename, "r") as f:
try:
- data = json.load(f)
+ data = _json.load(f)
if not isinstance(data, dict):
data = {}
except:
@@ -66,7 +54,7 @@ def save_json_dict(filename, json_dict):
"""Save json to file. Error if path DNE, not a dict, or invalid json."""
if isinstance(json_dict, dict):
# this will raise a TypeError if something goes wrong
- json_string = json.dumps(json_dict, indent=4)
+ json_string = _json.dumps(json_dict, indent=4)
lock.acquire()
with open(filename, "w") as f:
f.write(json_string)
@@ -112,7 +100,7 @@ class NotEncodable(Exception):
pass
-class PlotlyJSONEncoder(json.JSONEncoder):
+class PlotlyJSONEncoder(_json.JSONEncoder):
"""
Meant to be passed as the `cls` kwarg to json.dumps(obj, cls=..)
@@ -149,7 +137,8 @@ def encode(self, o):
# 1. `loads` to switch Infinity, -Infinity, NaN to None
# 2. `dumps` again so you get 'null' instead of extended JSON
try:
- new_o = json.loads(encoded_o, parse_constant=self.coerce_to_strict)
+ new_o = _json.loads(encoded_o,
+ parse_constant=self.coerce_to_strict)
except ValueError:
# invalid separators will fail here. raise a helpful exception
@@ -158,10 +147,10 @@ def encode(self, o):
"valid JSON separators?"
)
else:
- return json.dumps(new_o, sort_keys=self.sort_keys,
- indent=self.indent,
- separators=(self.item_separator,
- self.key_separator))
+ return _json.dumps(new_o, sort_keys=self.sort_keys,
+ indent=self.indent,
+ separators=(self.item_separator,
+ self.key_separator))
def default(self, obj):
"""
@@ -210,7 +199,7 @@ def default(self, obj):
return encoding_method(obj)
except NotEncodable:
pass
- return json.JSONEncoder.default(self, obj)
+ return _json.JSONEncoder.default(self, obj)
@staticmethod
def encode_as_plotly(obj):
@@ -231,12 +220,12 @@ def encode_as_list(obj):
@staticmethod
def encode_as_sage(obj):
"""Attempt to convert sage.all.RR to floats and sage.all.ZZ to ints"""
- if not _sage_imported:
+ if not sage_all:
raise NotEncodable
- if obj in sage.all.RR:
+ if obj in sage_all.RR:
return float(obj)
- elif obj in sage.all.ZZ:
+ elif obj in sage_all.ZZ:
return int(obj)
else:
raise NotEncodable
@@ -244,7 +233,7 @@ def encode_as_sage(obj):
@staticmethod
def encode_as_pandas(obj):
"""Attempt to convert pandas.NaT"""
- if not _pandas_imported:
+ if not pandas:
raise NotEncodable
if obj is pandas.NaT:
@@ -255,7 +244,7 @@ def encode_as_pandas(obj):
@staticmethod
def encode_as_numpy(obj):
"""Attempt to convert numpy.ma.core.masked"""
- if not _numpy_imported:
+ if not numpy:
raise NotEncodable
if obj is numpy.ma.core.masked:
diff --git a/plotly/version.py b/plotly/version.py
index 84c54b74824..afced14728f 100644
--- a/plotly/version.py
+++ b/plotly/version.py
@@ -1 +1 @@
-__version__ = '1.13.0'
+__version__ = '2.0.0'
diff --git a/plotly/widgets/graph_widget.py b/plotly/widgets/graph_widget.py
index a359474a7d0..f7eb0a86084 100644
--- a/plotly/widgets/graph_widget.py
+++ b/plotly/widgets/graph_widget.py
@@ -2,11 +2,11 @@
Module to allow Plotly graphs to interact with IPython widgets.
"""
-import json
import uuid
from collections import deque
from pkg_resources import resource_string
+from requests.compat import json as _json
# TODO: protected imports?
from IPython.html import widgets
@@ -93,7 +93,7 @@ def _handle_msg(self, message):
while self._clientMessages:
_message = self._clientMessages.popleft()
_message['graphId'] = self._graphId
- _message = json.dumps(_message)
+ _message = _json.dumps(_message)
self._message = _message
if content.get('event', '') in ['click', 'hover', 'zoom']:
@@ -131,7 +131,7 @@ def _handle_outgoing_message(self, message):
else:
message['graphId'] = self._graphId
message['uid'] = str(uuid.uuid4())
- self._message = json.dumps(message, cls=utils.PlotlyJSONEncoder)
+ self._message = _json.dumps(message, cls=utils.PlotlyJSONEncoder)
def on_click(self, callback, remove=False):
""" Assign a callback to click events propagated
diff --git a/setup.py b/setup.py
index 9c50bcbf475..1e66057b516 100644
--- a/setup.py
+++ b/setup.py
@@ -31,8 +31,12 @@ def readme():
],
license='MIT',
packages=['plotly',
+ 'plotly/api',
+ 'plotly/api/v1',
+ 'plotly/api/v2',
'plotly/plotly',
'plotly/plotly/chunked_requests',
+ 'plotly/figure_factory',
'plotly/graph_objs',
'plotly/grid_objs',
'plotly/widgets',