diff --git a/CHANGELOG.md b/CHANGELOG.md index 985f3c231d0..def553f9704 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ This project adheres to [Semantic Versioning](http://semver.org/). - Fixed special cases with `px.sunburst` and `px.treemap` with `path` input ([#2524](https://github.com/plotly/plotly.py/pull/2524)) +### Added + +- New `create_hexbin_mapbox` figure factory for hexagonal binning on Mapbox maps + ## [4.8.1] - 2020-05-28 ### Fixed diff --git a/doc/python/figure-factories.md b/doc/python/figure-factories.md index 881185e332a..27fd4896aa0 100644 --- a/doc/python/figure-factories.md +++ b/doc/python/figure-factories.md @@ -42,6 +42,7 @@ The following types of plots are still difficult to create with Graph Objects or * [Annotated Heatmaps](/python/annotated-heatmap/) * [Dendrograms](/python/dendrogram/) * [Gantt Charts](/python/gantt/) + * [Hexagonal Binning Mapbox](/python/hexbin-mapbox/) * [Quiver Plots](/python/quiver-plots/) * [Streamline Plots](/python/streamline-plots/) * [Tables](/python/figure-factory-table/) diff --git a/doc/python/hexbin-mapbox.md b/doc/python/hexbin-mapbox.md new file mode 100644 index 00000000000..4f010ac224a --- /dev/null +++ b/doc/python/hexbin-mapbox.md @@ -0,0 +1,167 @@ +--- +jupyter: + jupytext: + notebook_metadata_filter: all + text_representation: + extension: .md + format_name: markdown + format_version: '1.2' + jupytext_version: 1.5.1 + kernelspec: + display_name: Python 3 + language: python + name: python3 + language_info: + codemirror_mode: + name: ipython + version: 3 + file_extension: .py + mimetype: text/x-python + name: python + nbconvert_exporter: python + pygments_lexer: ipython3 + version: 3.7.4 + plotly: + description: How to make a map with Hexagonal Binning of data in Python with Plotly. 
+ display_as: scientific + language: python + layout: base + name: Hexbin Mapbox + order: 7 + page_type: u-guide + permalink: python/hexbin-mapbox/ + redirect_from: python/hexbin-mapbox/ + thumbnail: thumbnail/hexbin_mapbox.jpg +--- + +#### Simple Count Hexbin + +This page details the use of a [figure factory](/python/figure-factories/). For more examples with Choropleth maps, see [this page](/python/choropleth-maps/). + +In order to use mapbox styles that require a mapbox token, set the token with `plotly.express`. You can also use styles that do not require a mapbox token. See more information on [this page](/python/mapbox-layers/). + +```python +import plotly.figure_factory as ff +import plotly.express as px + +px.set_mapbox_access_token(open(".mapbox_token").read()) +df = px.data.carshare() + +fig = ff.create_hexbin_mapbox( + data_frame=df, lat="centroid_lat", lon="centroid_lon", + nx_hexagon=10, opacity=0.9, labels={"color": "Point Count"}, +) +fig.update_layout(margin=dict(b=0, t=0, l=0, r=0)) +fig.show() +``` + +#### Count Hexbin with Minimum Count + +```python +import plotly.figure_factory as ff +import plotly.express as px + +px.set_mapbox_access_token(open(".mapbox_token").read()) +df = px.data.carshare() + +fig = ff.create_hexbin_mapbox( + data_frame=df, lat="centroid_lat", lon="centroid_lon", + nx_hexagon=10, opacity=0.9, labels={"color": "Point Count"}, + min_count=1, +) +fig.show() +``` + +#### Display the Underlying Data + +```python +import plotly.figure_factory as ff +import plotly.express as px + +px.set_mapbox_access_token(open(".mapbox_token").read()) +df = px.data.carshare() + +fig = ff.create_hexbin_mapbox( + data_frame=df, lat="centroid_lat", lon="centroid_lon", + nx_hexagon=10, opacity=0.9, labels={"color": "Point Count"}, + min_count=1, color_continuous_scale="Viridis", + show_original_data=True, + original_data_marker=dict(size=4, opacity=0.6, color="deeppink") +) +fig.show() +``` + +#### Compute the Mean Value per Hexbin + +```python 
+import plotly.figure_factory as ff +import plotly.express as px +import numpy as np + +px.set_mapbox_access_token(open(".mapbox_token").read()) +df = px.data.carshare() + +fig = ff.create_hexbin_mapbox( + data_frame=df, lat="centroid_lat", lon="centroid_lon", + nx_hexagon=10, opacity=0.9, labels={"color": "Average Peak Hour"}, + color="peak_hour", agg_func=np.mean, color_continuous_scale="Icefire", range_color=[0,23] +) +fig.show() +``` + +#### Compute the Sum Value per Hexbin + +```python +import plotly.figure_factory as ff +import plotly.express as px +import numpy as np + +px.set_mapbox_access_token(open(".mapbox_token").read()) +df = px.data.carshare() + +fig = ff.create_hexbin_mapbox( + data_frame=df, lat="centroid_lat", lon="centroid_lon", + nx_hexagon=10, opacity=0.9, labels={"color": "Summed Car.Hours"}, + color="car_hours", agg_func=np.sum, color_continuous_scale="Magma" +) +fig.show() +``` + +#### Hexbin with Animation + +```python +import plotly.figure_factory as ff +import plotly.express as px +import numpy as np + +px.set_mapbox_access_token(open(".mapbox_token").read()) +np.random.seed(0) + +N = 500 +n_frames = 12 +lat = np.concatenate([ + np.random.randn(N) * 0.5 + np.cos(i / n_frames * 2 * np.pi) + for i in range(n_frames) +]) +lon = np.concatenate([ + np.random.randn(N) * 0.5 + np.sin(i / n_frames * 2 * np.pi) + for i in range(n_frames) +]) +frame = np.concatenate([ + np.ones(N, int) * i for i in range(n_frames) +]) + +fig = ff.create_hexbin_mapbox( + lat=lat, lon=lon, nx_hexagon=15, animation_frame=frame, + color_continuous_scale="Cividis", labels={"color": "Point Count", "frame": "Period"}, + show_original_data=True, original_data_marker=dict(opacity=0.6, size=4, color="deeppink") +) +fig.update_layout(margin=dict(b=0, t=0, l=0, r=0)) +fig.layout.sliders[0].pad.t=20 +fig.layout.updatemenus[0].pad.t=40 +fig.show() +``` + +#### Reference + +For more info on Plotly maps, see: https://plotly.com/python/maps.
For more info on using colorscales with Plotly, see: https://plotly.com/python/heatmap-and-contour-colorscales/
For more info on `ff.create_hexbin_mapbox()`, see the [full function reference](https://plotly.com/python-api-reference/generated/plotly.figure_factory.create_hexbin_mapbox.html#plotly.figure_factory.create_hexbin_mapbox) diff --git a/packages/python/plotly/plotly/figure_factory/__init__.py b/packages/python/plotly/plotly/figure_factory/__init__.py index 3829ca2fb67..0a41dca1ba2 100644 --- a/packages/python/plotly/plotly/figure_factory/__init__.py +++ b/packages/python/plotly/plotly/figure_factory/__init__.py @@ -29,11 +29,15 @@ if optional_imports.get_module("pandas") is not None: from plotly.figure_factory._county_choropleth import create_choropleth + from plotly.figure_factory._hexbin_mapbox import create_hexbin_mapbox else: def create_choropleth(*args, **kwargs): raise ImportError("Please install pandas to use `create_choropleth`") + def create_hexbin_mapbox(*args, **kwargs): + raise ImportError("Please install pandas to use `create_hexbin_mapbox`") + if optional_imports.get_module("skimage") is not None: from plotly.figure_factory._ternary_contour import create_ternary_contour @@ -53,6 +57,7 @@ def create_ternary_contour(*args, **kwargs): "create_distplot", "create_facet_grid", "create_gantt", + "create_hexbin_mapbox", "create_ohlc", "create_quiver", "create_scatterplotmatrix", diff --git a/packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py b/packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py new file mode 100644 index 00000000000..6ee22245753 --- /dev/null +++ b/packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py @@ -0,0 +1,492 @@ +from plotly.express._core import build_dataframe +from plotly.express._doc import make_docstring +from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox +import numpy as np +import pandas as pd + + +def _project_latlon_to_wgs84(lat, lon): + """ + Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map + """ + x = lon * np.pi / 180 + y = 
np.arctanh(np.sin(lat * np.pi / 180)) + return x, y + + +def _project_wgs84_to_latlon(x, y): + """ + Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map + """ + lon = x * 180 / np.pi + lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi + return lat, lon + + +def _getBoundsZoomLevel(lon_min, lon_max, lat_min, lat_max, mapDim): + """ + Get the mapbox zoom level given bounds and a figure dimension + Source: https://stackoverflow.com/questions/6048975/google-maps-v3-how-to-calculate-the-zoom-level-for-a-given-bounds + """ + + scale = ( + 2 # adjustment to reflect MapBox base tiles are 512x512 vs. Google's 256x256 + ) + WORLD_DIM = {"height": 256 * scale, "width": 256 * scale} + ZOOM_MAX = 18 + + def latRad(lat): + sin = np.sin(lat * np.pi / 180) + radX2 = np.log((1 + sin) / (1 - sin)) / 2 + return max(min(radX2, np.pi), -np.pi) / 2 + + def zoom(mapPx, worldPx, fraction): + return 0.95 * np.log(mapPx / worldPx / fraction) / np.log(2) + + latFraction = (latRad(lat_max) - latRad(lat_min)) / np.pi + + lngDiff = lon_max - lon_min + lngFraction = ((lngDiff + 360) if lngDiff < 0 else lngDiff) / 360 + + latZoom = zoom(mapDim["height"], WORLD_DIM["height"], latFraction) + lngZoom = zoom(mapDim["width"], WORLD_DIM["width"], lngFraction) + + return min(latZoom, lngZoom, ZOOM_MAX) + + +def _compute_hexbin(x, y, x_range, y_range, color, nx, agg_func, min_count): + """ + Computes the aggregation at hexagonal bin level. + Also defines the coordinates of the hexagons for plotting. + The binning is inspired by matplotlib's implementation. 
+ + Parameters + ---------- + x : np.ndarray + Array of x values (shape N) + y : np.ndarray + Array of y values (shape N) + x_range : np.ndarray + Min and max x (shape 2) + y_range : np.ndarray + Min and max y (shape 2) + color : np.ndarray + Metric to aggregate at hexagon level (shape N) + nx : int + Number of hexagons horizontally + agg_func : function + Numpy compatible aggregator, this function must take a one-dimensional + np.ndarray as input and output a scalar + min_count : int + Minimum number of points in the hexagon for the hexagon to be displayed + + Returns + ------- + np.ndarray + X coordinates of each hexagon (shape M x 6) + np.ndarray + Y coordinates of each hexagon (shape M x 6) + np.ndarray + Centers of the hexagons (shape M x 2) + np.ndarray + Aggregated value in each hexagon (shape M) + + """ + xmin = x_range.min() + xmax = x_range.max() + ymin = y_range.min() + ymax = y_range.max() + + # In the x-direction, the hexagons exactly cover the region from + # xmin to xmax. Need some padding to avoid roundoff errors. 
+ padding = 1.0e-9 * (xmax - xmin) + xmin -= padding + xmax += padding + + Dx = xmax - xmin + Dy = ymax - ymin + if Dx == 0 and Dy > 0: + dx = Dy / nx + elif Dx == 0 and Dy == 0: + dx, _ = _project_latlon_to_wgs84(1, 1) + else: + dx = Dx / nx + dy = dx * np.sqrt(3) + ny = np.ceil(Dy / dy).astype(int) + + # Center the hexagons vertically since we only want regular hexagons + ymin -= (ymin + dy * ny - ymax) / 2 + + x = (x - xmin) / dx + y = (y - ymin) / dy + ix1 = np.round(x).astype(int) + iy1 = np.round(y).astype(int) + ix2 = np.floor(x).astype(int) + iy2 = np.floor(y).astype(int) + + nx1 = nx + 1 + ny1 = ny + 1 + nx2 = nx + ny2 = ny + n = nx1 * ny1 + nx2 * ny2 + + d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2 + d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2 + bdist = d1 < d2 + + if color is None: + lattice1 = np.zeros((nx1, ny1)) + lattice2 = np.zeros((nx2, ny2)) + c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist + c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist + np.add.at(lattice1, (ix1[c1], iy1[c1]), 1) + np.add.at(lattice2, (ix2[c2], iy2[c2]), 1) + if min_count is not None: + lattice1[lattice1 < min_count] = np.nan + lattice2[lattice2 < min_count] = np.nan + accum = np.concatenate([lattice1.ravel(), lattice2.ravel()]) + good_idxs = ~np.isnan(accum) + else: + if min_count is None: + min_count = 1 + + # create accumulation arrays + lattice1 = np.empty((nx1, ny1), dtype=object) + for i in range(nx1): + for j in range(ny1): + lattice1[i, j] = [] + lattice2 = np.empty((nx2, ny2), dtype=object) + for i in range(nx2): + for j in range(ny2): + lattice2[i, j] = [] + + for i in range(len(x)): + if bdist[i]: + if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1: + lattice1[ix1[i], iy1[i]].append(color[i]) + else: + if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2: + lattice2[ix2[i], iy2[i]].append(color[i]) + + for i in range(nx1): + for j in range(ny1): + vals = lattice1[i, j] + if len(vals) >= min_count: + lattice1[i, j] = agg_func(vals) + else: 
+ lattice1[i, j] = np.nan + for i in range(nx2): + for j in range(ny2): + vals = lattice2[i, j] + if len(vals) >= min_count: + lattice2[i, j] = agg_func(vals) + else: + lattice2[i, j] = np.nan + + accum = np.hstack( + (lattice1.astype(float).ravel(), lattice2.astype(float).ravel()) + ) + good_idxs = ~np.isnan(accum) + + agreggated_value = accum[good_idxs] + + centers = np.zeros((n, 2), float) + centers[: nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1) + centers[: nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1) + centers[nx1 * ny1 :, 0] = np.repeat(np.arange(nx2) + 0.5, ny2) + centers[nx1 * ny1 :, 1] = np.tile(np.arange(ny2), nx2) + 0.5 + centers[:, 0] *= dx + centers[:, 1] *= dy + centers[:, 0] += xmin + centers[:, 1] += ymin + centers = centers[good_idxs] + + # Define normalised regular hexagon coordinates + hx = [0, 0.5, 0.5, 0, -0.5, -0.5] + hy = [ + -0.5 / np.cos(np.pi / 6), + -0.5 * np.tan(np.pi / 6), + 0.5 * np.tan(np.pi / 6), + 0.5 / np.cos(np.pi / 6), + 0.5 * np.tan(np.pi / 6), + -0.5 * np.tan(np.pi / 6), + ] + + # Number of hexagons needed + m = len(centers) + + # Coordinates for all hexagonal patches + hxs = np.array([hx] * m) * dx + np.vstack(centers[:, 0]) + hys = np.array([hy] * m) * dy / np.sqrt(3) + np.vstack(centers[:, 1]) + + return hxs, hys, centers, agreggated_value + + +def _compute_wgs84_hexbin( + lat=None, + lon=None, + lat_range=None, + lon_range=None, + color=None, + nx=None, + agg_func=None, + min_count=None, +): + """ + Computes the lat-lon aggregation at hexagonal bin level. + Latitude and longitude need to be projected to WGS84 before aggregating + in order to display regular hexagons on the map. 
+ + Parameters + ---------- + lat : np.ndarray + Array of latitudes (shape N) + lon : np.ndarray + Array of longitudes (shape N) + lat_range : np.ndarray + Min and max latitudes (shape 2) + lon_range : np.ndarray + Min and max longitudes (shape 2) + color : np.ndarray + Metric to aggregate at hexagon level (shape N) + nx : int + Number of hexagons horizontally + agg_func : function + Numpy compatible aggregator, this function must take a one-dimensional + np.ndarray as input and output a scalar + min_count : int + Minimum number of points in the hexagon for the hexagon to be displayed + + Returns + ------- + np.ndarray + Lat coordinates of each hexagon (shape M x 6) + np.ndarray + Lon coordinates of each hexagon (shape M x 6) + pd.Series + Unique id for each hexagon, to be used in the geojson data (shape M) + np.ndarray + Aggregated value in each hexagon (shape M) + + """ + # Project to WGS 84 + x, y = _project_latlon_to_wgs84(lat, lon) + + if lat_range is None: + lat_range = np.array([lat.min(), lat.max()]) + if lon_range is None: + lon_range = np.array([lon.min(), lon.max()]) + + x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range) + + hxs, hys, centers, agreggated_value = _compute_hexbin( + x, y, x_range, y_range, color, nx, agg_func, min_count + ) + + # Convert back to lat-lon + hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys) + + # Create unique feature id based on hexagon center + centers = centers.astype(str) + hexagons_ids = pd.Series(centers[:, 0]) + "," + pd.Series(centers[:, 1]) + + return hexagons_lats, hexagons_lons, hexagons_ids, agreggated_value + + +def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None): + """ + Creates a geojson of hexagonal features based on the outputs of + _compute_wgs84_hexbin + """ + features = [] + if ids is None: + ids = np.arange(len(hexagons_lats)) + for lat, lon, idx in zip(hexagons_lats, hexagons_lons, ids): + points = np.array([lon, lat]).T.tolist() + points.append(points[0]) + 
features.append( + dict( + type="Feature", + id=idx, + geometry=dict(type="Polygon", coordinates=[points]), + ) + ) + return dict(type="FeatureCollection", features=features) + + +def create_hexbin_mapbox( + data_frame=None, + lat=None, + lon=None, + color=None, + nx_hexagon=5, + agg_func=None, + animation_frame=None, + color_discrete_sequence=None, + color_discrete_map={}, + labels={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + zoom=None, + center=None, + mapbox_style=None, + title=None, + template=None, + width=None, + height=None, + min_count=None, + show_original_data=False, + original_data_marker=None, +): + """ + Returns a figure aggregating scattered points into connected hexagons + """ + args = build_dataframe(args=locals(), constructor=None) + + if agg_func is None: + agg_func = np.mean + + lat_range = args["data_frame"][args["lat"]].agg(["min", "max"]).values + lon_range = args["data_frame"][args["lon"]].agg(["min", "max"]).values + + hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin( + lat=args["data_frame"][args["lat"]].values, + lon=args["data_frame"][args["lon"]].values, + lat_range=lat_range, + lon_range=lon_range, + color=None, + nx=nx_hexagon, + agg_func=agg_func, + min_count=min_count, + ) + + geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids) + + if zoom is None: + if height is None and width is None: + mapDim = dict(height=450, width=450) + elif height is None and width is not None: + mapDim = dict(height=450, width=width) + elif height is not None and width is None: + mapDim = dict(height=height, width=height) + else: + mapDim = dict(height=height, width=width) + zoom = _getBoundsZoomLevel( + lon_range[0], lon_range[1], lat_range[0], lat_range[1], mapDim + ) + + if center is None: + center = dict(lat=lat_range.mean(), lon=lon_range.mean()) + + if args["animation_frame"] is not None: + groups = 
args["data_frame"].groupby(args["animation_frame"]).groups + else: + groups = {0: args["data_frame"].index} + + agg_data_frame_list = [] + for frame, index in groups.items(): + df = args["data_frame"].loc[index] + _, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin( + lat=df[args["lat"]].values, + lon=df[args["lon"]].values, + lat_range=lat_range, + lon_range=lon_range, + color=df[args["color"]].values if args["color"] else None, + nx=nx_hexagon, + agg_func=agg_func, + min_count=min_count, + ) + agg_data_frame_list.append( + pd.DataFrame( + np.c_[hexagons_ids, aggregated_value], columns=["locations", "color"] + ) + ) + agg_data_frame = ( + pd.concat(agg_data_frame_list, axis=0, keys=groups.keys()) + .rename_axis(index=("frame", "index")) + .reset_index("frame") + ) + + agg_data_frame["color"] = pd.to_numeric(agg_data_frame["color"]) + + if range_color is None: + range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()] + + fig = choropleth_mapbox( + data_frame=agg_data_frame, + geojson=geojson, + locations="locations", + color="color", + hover_data={"color": True, "locations": False, "frame": False}, + animation_frame=("frame" if args["animation_frame"] is not None else None), + color_discrete_sequence=color_discrete_sequence, + color_discrete_map=color_discrete_map, + labels=labels, + color_continuous_scale=color_continuous_scale, + range_color=range_color, + color_continuous_midpoint=color_continuous_midpoint, + opacity=opacity, + zoom=zoom, + center=center, + mapbox_style=mapbox_style, + title=title, + template=template, + width=width, + height=height, + ) + + if show_original_data: + original_fig = scatter_mapbox( + data_frame=( + args["data_frame"].sort_values(by=args["animation_frame"]) + if args["animation_frame"] is not None + else args["data_frame"] + ), + lat=args["lat"], + lon=args["lon"], + animation_frame=args["animation_frame"], + ) + original_fig.data[0].hoverinfo = "skip" + original_fig.data[0].hovertemplate = None + 
original_fig.data[0].marker = original_data_marker + + fig.add_trace(original_fig.data[0]) + + if args["animation_frame"] is not None: + for i in range(len(original_fig.frames)): + original_fig.frames[i].data[0].hoverinfo = "skip" + original_fig.frames[i].data[0].hovertemplate = None + original_fig.frames[i].data[0].marker = original_data_marker + + fig.frames[i].data = [ + fig.frames[i].data[0], + original_fig.frames[i].data[0], + ] + + return fig + + +create_hexbin_mapbox.__doc__ = make_docstring( + create_hexbin_mapbox, + override_dict=dict( + nx_hexagon=["int", "Number of hexagons (horizontally) to be created"], + agg_func=[ + "function", + "Numpy array aggregator, it must take as input a 1D array", + "and output a scalar value.", + ], + min_count=[ + "int", + "Minimum number of points in a hexagon for it to be displayed.", + "If None and color is not set, display all hexagons.", + "If None and color is set, only display hexagons that contain points.", + ], + show_original_data=[ + "bool", + "Whether to show the original data on top of the hexbin aggregation.", + ], + original_data_marker=["dict", "Scattermapbox marker options."], + ), +) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py b/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py index 0106b9dc58f..807b0f60423 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py @@ -4307,3 +4307,202 @@ def test_optional_arguments(self): # This test does not work for ilr interpolation print(len(fig.data)) assert len(fig.data) == ncontours + 2 + arg_set["showscale"] + + +class TestHexbinMapbox(NumpyTestUtilsMixin, TestCaseNoTemplate): + def test_aggregation(self): + + lat = [0, 1, 1, 2, 4, 5, 1, 2, 4, 5, 2, 3, 2, 1, 5, 3, 5] + lon = [1, 2, 3, 3, 0, 4, 5, 0, 5, 3, 1, 5, 4, 0, 
1, 2, 5] + color = np.ones(len(lat)) + + fig1 = ff.create_hexbin_mapbox(lat=lat, lon=lon, nx_hexagon=1) + + actual_geojson = { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "id": "-8.726646259971648e-11,-0.031886255679892235", + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [-5e-09, -4.7083909316316985], + [2.4999999999999996, -3.268549270944215], + [2.4999999999999996, -0.38356933397072673], + [-5e-09, 1.0597430482129082], + [-2.50000001, -0.38356933397072673], + [-2.50000001, -3.268549270944215], + [-5e-09, -4.7083909316316985], + ] + ], + }, + }, + { + "type": "Feature", + "id": "-8.726646259971648e-11,0.1192636916419258", + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [-5e-09, 3.9434377827164666], + [2.4999999999999996, 5.381998306154031], + [2.4999999999999996, 8.248045720432454], + [-5e-09, 9.673766164509932], + [-2.50000001, 8.248045720432454], + [-2.50000001, 5.381998306154031], + [-5e-09, 3.9434377827164666], + ] + ], + }, + }, + { + "type": "Feature", + "id": "0.08726646268698293,-0.031886255679892235", + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [5.0000000049999995, -4.7083909316316985], + [7.500000009999999, -3.268549270944215], + [7.500000009999999, -0.38356933397072673], + [5.0000000049999995, 1.0597430482129082], + [2.5, -0.38356933397072673], + [2.5, -3.268549270944215], + [5.0000000049999995, -4.7083909316316985], + ] + ], + }, + }, + { + "type": "Feature", + "id": "0.08726646268698293,0.1192636916419258", + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [5.0000000049999995, 3.9434377827164666], + [7.500000009999999, 5.381998306154031], + [7.500000009999999, 8.248045720432454], + [5.0000000049999995, 9.673766164509932], + [2.5, 8.248045720432454], + [2.5, 5.381998306154031], + [5.0000000049999995, 3.9434377827164666], + ] + ], + }, + }, + { + "type": "Feature", + "id": "0.04363323129985823,0.04368871798101678", + "geometry": { + "type": "Polygon", + 
"coordinates": [ + [ + [2.4999999999999996, -0.38356933397072673], + [5.0000000049999995, 1.0597430482129082], + [5.0000000049999995, 3.9434377827164666], + [2.4999999999999996, 5.381998306154031], + [-5.0000001310894304e-09, 3.9434377827164666], + [-5.0000001310894304e-09, 1.0597430482129082], + [2.4999999999999996, -0.38356933397072673], + ] + ], + }, + }, + ], + } + + actual_agg = [2.0, 2.0, 1.0, 3.0, 9.0] + + self.assert_dict_equal(fig1.data[0].geojson, actual_geojson) + assert np.array_equal(fig1.data[0].z, actual_agg) + + fig2 = ff.create_hexbin_mapbox( + lat=lat, lon=lon, nx_hexagon=1, color=color, agg_func=np.mean, + ) + + assert np.array_equal(fig2.data[0].z, np.ones(5)) + + fig3 = ff.create_hexbin_mapbox( + lat=np.random.randn(1000), lon=np.random.randn(1000), nx_hexagon=20, + ) + + assert fig3.data[0].z.sum() == 1000 + + def test_build_dataframe(self): + np.random.seed(0) + N = 10000 + nx_hexagon = 20 + n_frames = 3 + + lat = np.random.randn(N) + lon = np.random.randn(N) + color = np.ones(N) + frame = np.random.randint(0, n_frames, N) + df = pd.DataFrame( + np.c_[lat, lon, color, frame], + columns=["Latitude", "Longitude", "Metric", "Frame"], + ) + + fig1 = ff.create_hexbin_mapbox(lat=lat, lon=lon, nx_hexagon=nx_hexagon) + fig2 = ff.create_hexbin_mapbox( + data_frame=df, lat="Latitude", lon="Longitude", nx_hexagon=nx_hexagon + ) + + assert isinstance(fig1, go.Figure) + assert len(fig1.data) == 1 + self.assert_dict_equal( + fig1.to_plotly_json()["data"][0], fig2.to_plotly_json()["data"][0] + ) + + fig3 = ff.create_hexbin_mapbox( + lat=lat, + lon=lon, + nx_hexagon=nx_hexagon, + color=color, + agg_func=np.sum, + min_count=0, + ) + fig4 = ff.create_hexbin_mapbox( + lat=lat, lon=lon, nx_hexagon=nx_hexagon, color=color, agg_func=np.sum, + ) + fig5 = ff.create_hexbin_mapbox( + data_frame=df, + lat="Latitude", + lon="Longitude", + nx_hexagon=nx_hexagon, + color="Metric", + agg_func=np.sum, + ) + + self.assert_dict_equal( + fig1.to_plotly_json()["data"][0], 
fig3.to_plotly_json()["data"][0] + ) + self.assert_dict_equal( + fig4.to_plotly_json()["data"][0], fig5.to_plotly_json()["data"][0] + ) + + fig6 = ff.create_hexbin_mapbox( + data_frame=df, + lat="Latitude", + lon="Longitude", + nx_hexagon=nx_hexagon, + color="Metric", + agg_func=np.sum, + animation_frame="Frame", + ) + + fig7 = ff.create_hexbin_mapbox( + lat=lat, + lon=lon, + nx_hexagon=nx_hexagon, + color=color, + agg_func=np.sum, + animation_frame=frame, + ) + + assert len(fig6.frames) == n_frames + assert len(fig7.frames) == n_frames + assert fig6.data[0].geojson == fig1.data[0].geojson