Skip to content

Commit 20d0194

Browse files
alembckebehacklMrDiverpre-commit-ci[bot]chopan050
authored
Added colorscale to axes.plot() (#3148)
* add colorscale to plot * Update manim/mobject/graphing/coordinate_systems.py Co-authored-by: Benjamin Hackl <[email protected]> * updated typing and moved one line * added test * fix input_to_graph_point error * Performance improvement by using cairo color drawing * Add OpenGL support * Add OpenGL tests and split test for x and y axis for more behavior coverage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updated gradient_line tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Benjamin Hackl <[email protected]> Co-authored-by: MrDiver <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Francisco Manríquez Novoa <[email protected]> Co-authored-by: chopan <[email protected]>
1 parent 3a4ab4c commit 20d0194

File tree

7 files changed

+123
-0
lines changed

7 files changed

+123
-0
lines changed

manim/mobject/graphing/coordinate_systems.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
ManimColor,
4949
ParsableManimColor,
5050
color_gradient,
51+
interpolate_color,
5152
invert_color,
5253
)
5354
from manim.utils.config_ops import merge_dicts_recursively, update_dict_recursively
@@ -628,6 +629,8 @@ def plot(
628629
function: Callable[[float], float],
629630
x_range: Sequence[float] | None = None,
630631
use_vectorized: bool = False,
632+
colorscale: Union[Iterable[Color], Iterable[Color, float]] | None = None,
633+
colorscale_axis: int = 1,
631634
**kwargs: Any,
632635
) -> ParametricFunction:
633636
"""Generates a curve based on a function.
@@ -641,6 +644,12 @@ def plot(
641644
use_vectorized
642645
Whether to pass in the generated t value array to the function. Only use this if your function supports it.
643646
Output should be a numpy array of shape ``[y_0, y_1, ...]``
647+
colorscale
648+
Colors of the function. Optional parameter used when coloring a function by values. Passing a list of colors
649+
and a colorscale_axis will color the function by y-value. Passing a list of tuples in the form ``(color, pivot)``
650+
allows user-defined pivots where the color transitions.
651+
colorscale_axis
652+
Defines the axis on which the colorscale is applied (0 = x, 1 = y), default is y-axis (1).
644653
kwargs
645654
Additional parameters to be passed to :class:`~.ParametricFunction`.
646655
@@ -719,7 +728,57 @@ def log_func(x):
719728
use_vectorized=use_vectorized,
720729
**kwargs,
721730
)
731+
722732
graph.underlying_function = function
733+
734+
if colorscale:
735+
if type(colorscale[0]) in (list, tuple):
736+
new_colors, pivots = [
737+
[i for i, j in colorscale],
738+
[j for i, j in colorscale],
739+
]
740+
else:
741+
new_colors = colorscale
742+
743+
ranges = [self.x_range, self.y_range]
744+
pivot_min = ranges[colorscale_axis][0]
745+
pivot_max = ranges[colorscale_axis][1]
746+
pivot_frequency = (pivot_max - pivot_min) / (len(new_colors) - 1)
747+
pivots = np.arange(
748+
start=pivot_min,
749+
stop=pivot_max + pivot_frequency,
750+
step=pivot_frequency,
751+
)
752+
753+
resolution = 0.01 if len(x_range) == 2 else x_range[2]
754+
sample_points = np.arange(x_range[0], x_range[1] + resolution, resolution)
755+
color_list = []
756+
for samp_x in sample_points:
757+
axis_value = (samp_x, function(samp_x))[colorscale_axis]
758+
if axis_value <= pivots[0]:
759+
color_list.append(new_colors[0])
760+
elif axis_value >= pivots[-1]:
761+
color_list.append(new_colors[-1])
762+
else:
763+
for i, pivot in enumerate(pivots):
764+
if pivot > axis_value:
765+
color_index = (axis_value - pivots[i - 1]) / (
766+
pivots[i] - pivots[i - 1]
767+
)
768+
color_index = min(color_index, 1)
769+
mob_color = interpolate_color(
770+
new_colors[i - 1],
771+
new_colors[i],
772+
color_index,
773+
)
774+
color_list.append(mob_color)
775+
break
776+
if config.renderer == RendererType.OPENGL:
777+
graph.set_color(color_list)
778+
else:
779+
graph.set_stroke(color_list)
780+
graph.set_sheen_direction(RIGHT)
781+
723782
return graph
724783

725784
def plot_implicit_curve(

tests/opengl/test_coordinate_system_opengl.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
tempconfig,
2121
)
2222
from manim import CoordinateSystem as CS
23+
from manim.utils.color import BLUE, GREEN, ORANGE, RED, YELLOW
24+
from manim.utils.testing.frames_comparison import frames_comparison
25+
26+
__module_test__ = "coordinate_system_opengl"
2327

2428

2529
def test_initial_config(using_opengl_renderer):
@@ -138,3 +142,33 @@ def test_input_to_graph_point(using_opengl_renderer):
138142
# test the line_graph implementation
139143
position = np.around(ax.input_to_graph_point(x=PI, graph=line_graph), decimals=4)
140144
np.testing.assert_array_equal(position, (2.6928, 1.2876, 0))
145+
146+
147+
@frames_comparison
148+
def test_gradient_line_graph_x_axis(scene, using_opengl_renderer):
149+
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
150+
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])
151+
152+
curve = axes.plot(
153+
lambda x: 0.1 * x**3,
154+
x_range=(-3, 3, 0.001),
155+
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
156+
colorscale_axis=0,
157+
)
158+
159+
scene.add(axes, curve)
160+
161+
162+
@frames_comparison
163+
def test_gradient_line_graph_y_axis(scene, using_opengl_renderer):
164+
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
165+
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])
166+
167+
curve = axes.plot(
168+
lambda x: 0.1 * x**3,
169+
x_range=(-3, 3, 0.001),
170+
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
171+
colorscale_axis=1,
172+
)
173+
174+
scene.add(axes, curve)

tests/test_graphical_units/test_coordinate_systems.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,33 @@ def test_number_plane_log(scene):
141141
)
142142

143143
scene.add(VGroup(plane1, plane2).arrange())
144+
145+
146+
@frames_comparison
147+
def test_gradient_line_graph_x_axis(scene):
148+
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
149+
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])
150+
151+
curve = axes.plot(
152+
lambda x: 0.1 * x**3,
153+
x_range=(-3, 3, 0.001),
154+
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
155+
colorscale_axis=0,
156+
)
157+
158+
scene.add(axes, curve)
159+
160+
161+
@frames_comparison
162+
def test_gradient_line_graph_y_axis(scene):
163+
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
164+
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])
165+
166+
curve = axes.plot(
167+
lambda x: 0.1 * x**3,
168+
x_range=(-3, 3, 0.001),
169+
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
170+
colorscale_axis=1,
171+
)
172+
173+
scene.add(axes, curve)

0 commit comments

Comments
 (0)