Skip to content

Commit e569fd4

Browse files
authored
test: add matrix tests (#4279)
1 parent bf72127 commit e569fd4

File tree

1 file changed

+212
-0
lines changed

1 file changed

+212
-0
lines changed

tests/module/mobject/test_matrix.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import pytest
5+
6+
from manim.mobject.matrix import (
7+
DecimalMatrix,
8+
IntegerMatrix,
9+
Matrix,
10+
)
11+
from manim.mobject.text.tex_mobject import MathTex
12+
from manim.mobject.types.vectorized_mobject import VGroup
13+
14+
15+
class TestMatrix:
16+
@pytest.mark.parametrize(
17+
(
18+
"matrix_elements",
19+
"left_bracket",
20+
"right_bracket",
21+
"expected_rows",
22+
"expected_columns",
23+
),
24+
[
25+
([[1, 2], [3, 4]], "[", "]", 2, 2),
26+
([[1, 2, 3]], "[", "]", 1, 3),
27+
([[1], [2], [3]], "[", "]", 3, 1),
28+
([[5]], "[", "]", 1, 1),
29+
([[1, 0], [0, 1]], "(", ")", 2, 2),
30+
([["a", "b"], ["c", "d"]], "[", "]", 2, 2),
31+
(np.array([[10, 20], [30, 40]]), "[", "]", 2, 2),
32+
],
33+
ids=[
34+
"2x2_default",
35+
"1x3_default",
36+
"3x1_default",
37+
"1x1_default",
38+
"2x2_parentheses",
39+
"2x2_strings",
40+
"2x2_numpy",
41+
],
42+
)
43+
def test_matrix_init_valid(
44+
self,
45+
matrix_elements,
46+
left_bracket,
47+
right_bracket,
48+
expected_rows,
49+
expected_columns,
50+
):
51+
matrix = Matrix(
52+
matrix_elements, left_bracket=left_bracket, right_bracket=right_bracket
53+
)
54+
55+
assert isinstance(matrix, Matrix)
56+
assert matrix.left_bracket == left_bracket
57+
assert matrix.right_bracket == right_bracket
58+
assert len(matrix.get_rows()) == expected_rows
59+
assert len(matrix.get_columns()) == expected_columns
60+
61+
@pytest.mark.parametrize(
62+
("invalid_elements", "expected_error"),
63+
[
64+
(10, TypeError),
65+
(10.4, TypeError),
66+
([1, 2, 3], TypeError),
67+
],
68+
ids=[
69+
"integer",
70+
"float",
71+
"flat_list",
72+
],
73+
)
74+
def test_matrix_init_invalid(self, invalid_elements, expected_error):
75+
with pytest.raises(expected_error):
76+
Matrix(invalid_elements)
77+
78+
@pytest.mark.parametrize(
79+
("matrix_elements", "expected_columns"),
80+
[
81+
([[1, 2], [3, 4]], 2),
82+
([[1, 2, 3]], 3),
83+
([[1], [2], [3]], 1),
84+
],
85+
ids=["2x2", "1x3", "3x1"],
86+
)
87+
def test_get_columns(self, matrix_elements, expected_columns):
88+
matrix = Matrix(matrix_elements)
89+
90+
assert isinstance(matrix, Matrix)
91+
assert len(matrix.get_columns()) == expected_columns
92+
for column in matrix.get_columns():
93+
assert isinstance(column, VGroup)
94+
95+
@pytest.mark.parametrize(
96+
("matrix_elements", "expected_rows"),
97+
[
98+
([[1, 2], [3, 4]], 2),
99+
([[1, 2, 3]], 1),
100+
([[1], [2], [3]], 3),
101+
],
102+
ids=["2x2", "1x3", "3x1"],
103+
)
104+
def test_get_rows(self, matrix_elements, expected_rows):
105+
matrix = Matrix(matrix_elements)
106+
107+
assert isinstance(matrix, Matrix)
108+
assert len(matrix.get_rows()) == expected_rows
109+
for row in matrix.get_rows():
110+
assert isinstance(row, VGroup)
111+
112+
@pytest.mark.parametrize(
113+
("matrix_elements", "expected_entries_tex_string", "expected_entries_count"),
114+
[
115+
([[1, 2], [3, 4]], ["1", "2", "3", "4"], 4),
116+
([[1, 2, 3]], ["1", "2", "3"], 3),
117+
],
118+
ids=["2x2", "1x3"],
119+
)
120+
def test_get_entries(
121+
self, matrix_elements, expected_entries_tex_string, expected_entries_count
122+
):
123+
matrix = Matrix(matrix_elements)
124+
entries = matrix.get_entries()
125+
126+
assert isinstance(matrix, Matrix)
127+
assert len(entries) == expected_entries_count
128+
for index_entry, entry in enumerate(entries):
129+
assert isinstance(entry, MathTex)
130+
assert expected_entries_tex_string[index_entry] == entry.tex_string
131+
132+
@pytest.mark.parametrize(
133+
("matrix_elements", "row", "column", "expected_value_str"),
134+
[
135+
([[1, 2], [3, 4]], 0, 0, "1"),
136+
([[1, 2], [3, 4]], 1, 1, "4"),
137+
([[1, 2, 3]], 0, 2, "3"),
138+
([[1], [2], [3]], 2, 0, "3"),
139+
],
140+
ids=["2x2_00", "2x2_11", "1x3_02", "3x1_20"],
141+
)
142+
def test_get_element(self, matrix_elements, row, column, expected_value_str):
143+
matrix = Matrix(matrix_elements)
144+
145+
assert isinstance(matrix.get_columns()[column][row], MathTex)
146+
assert isinstance(matrix.get_rows()[row][column], MathTex)
147+
assert matrix.get_columns()[column][row].tex_string == expected_value_str
148+
assert matrix.get_rows()[row][column].tex_string == expected_value_str
149+
150+
@pytest.mark.parametrize(
151+
("matrix_elements", "row", "column", "expected_error"),
152+
[
153+
([[1, 2]], 1, 0, IndexError),
154+
([[1, 2]], 0, 2, IndexError),
155+
],
156+
ids=["row_out_of_bounds", "col_out_of_bounds"],
157+
)
158+
def test_get_element_invalid(self, matrix_elements, row, column, expected_error):
159+
matrix = Matrix(matrix_elements)
160+
161+
with pytest.raises(expected_error):
162+
matrix.get_columns()[column][row]
163+
164+
with pytest.raises(expected_error):
165+
matrix.get_rows()[row][column]
166+
167+
168+
class TestDecimalMatrix:
169+
@pytest.mark.parametrize(
170+
("matrix_elements", "num_decimal_places", "expected_elements"),
171+
[
172+
([[1.234, 5.678], [9.012, 3.456]], 2, [[1.234, 5.678], [9.012, 3.456]]),
173+
([[1.0, 2.0], [3.0, 4.0]], 0, [[1, 2], [3, 4]]),
174+
([[1, 2.3], [4.567, 7]], 1, [[1.0, 2.3], [4.567, 7.0]]),
175+
],
176+
ids=[
177+
"basic_2_decimal_points",
178+
"basic_0_decimal_points",
179+
"mixed_1_decimal_points",
180+
],
181+
)
182+
def test_decimal_matrix_init(
183+
self, matrix_elements, num_decimal_places, expected_elements
184+
):
185+
matrix = DecimalMatrix(
186+
matrix_elements,
187+
element_to_mobject_config={"num_decimal_places": num_decimal_places},
188+
)
189+
190+
assert isinstance(matrix, DecimalMatrix)
191+
for column_index, column in enumerate(matrix.get_columns()):
192+
for row_index, element in enumerate(column):
193+
assert element.number == expected_elements[row_index][column_index]
194+
assert element.num_decimal_places == num_decimal_places
195+
196+
197+
class TestIntegerMatrix:
198+
@pytest.mark.parametrize(
199+
("matrix_elements", "expected_elements"),
200+
[
201+
([[1, 2], [3, 4]], [[1, 2], [3, 4]]),
202+
([[1.2, 2.8], [3.5, 4]], [[1.2, 2.8], [3.5, 4]]),
203+
],
204+
ids=["basic_int", "mixed_float_int"],
205+
)
206+
def test_integer_matrix_init(self, matrix_elements, expected_elements):
207+
matrix = IntegerMatrix(matrix_elements)
208+
209+
assert isinstance(matrix, IntegerMatrix)
210+
for row_index, row in enumerate(matrix.get_rows()):
211+
for column_index, element in enumerate(row):
212+
assert element.number == expected_elements[row_index][column_index]

0 commit comments

Comments
 (0)