|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pytest |
| 5 | + |
| 6 | +from manim.mobject.matrix import ( |
| 7 | + DecimalMatrix, |
| 8 | + IntegerMatrix, |
| 9 | + Matrix, |
| 10 | +) |
| 11 | +from manim.mobject.text.tex_mobject import MathTex |
| 12 | +from manim.mobject.types.vectorized_mobject import VGroup |
| 13 | + |
| 14 | + |
| 15 | +class TestMatrix: |
| 16 | + @pytest.mark.parametrize( |
| 17 | + ( |
| 18 | + "matrix_elements", |
| 19 | + "left_bracket", |
| 20 | + "right_bracket", |
| 21 | + "expected_rows", |
| 22 | + "expected_columns", |
| 23 | + ), |
| 24 | + [ |
| 25 | + ([[1, 2], [3, 4]], "[", "]", 2, 2), |
| 26 | + ([[1, 2, 3]], "[", "]", 1, 3), |
| 27 | + ([[1], [2], [3]], "[", "]", 3, 1), |
| 28 | + ([[5]], "[", "]", 1, 1), |
| 29 | + ([[1, 0], [0, 1]], "(", ")", 2, 2), |
| 30 | + ([["a", "b"], ["c", "d"]], "[", "]", 2, 2), |
| 31 | + (np.array([[10, 20], [30, 40]]), "[", "]", 2, 2), |
| 32 | + ], |
| 33 | + ids=[ |
| 34 | + "2x2_default", |
| 35 | + "1x3_default", |
| 36 | + "3x1_default", |
| 37 | + "1x1_default", |
| 38 | + "2x2_parentheses", |
| 39 | + "2x2_strings", |
| 40 | + "2x2_numpy", |
| 41 | + ], |
| 42 | + ) |
| 43 | + def test_matrix_init_valid( |
| 44 | + self, |
| 45 | + matrix_elements, |
| 46 | + left_bracket, |
| 47 | + right_bracket, |
| 48 | + expected_rows, |
| 49 | + expected_columns, |
| 50 | + ): |
| 51 | + matrix = Matrix( |
| 52 | + matrix_elements, left_bracket=left_bracket, right_bracket=right_bracket |
| 53 | + ) |
| 54 | + |
| 55 | + assert isinstance(matrix, Matrix) |
| 56 | + assert matrix.left_bracket == left_bracket |
| 57 | + assert matrix.right_bracket == right_bracket |
| 58 | + assert len(matrix.get_rows()) == expected_rows |
| 59 | + assert len(matrix.get_columns()) == expected_columns |
| 60 | + |
| 61 | + @pytest.mark.parametrize( |
| 62 | + ("invalid_elements", "expected_error"), |
| 63 | + [ |
| 64 | + (10, TypeError), |
| 65 | + (10.4, TypeError), |
| 66 | + ([1, 2, 3], TypeError), |
| 67 | + ], |
| 68 | + ids=[ |
| 69 | + "integer", |
| 70 | + "float", |
| 71 | + "flat_list", |
| 72 | + ], |
| 73 | + ) |
| 74 | + def test_matrix_init_invalid(self, invalid_elements, expected_error): |
| 75 | + with pytest.raises(expected_error): |
| 76 | + Matrix(invalid_elements) |
| 77 | + |
| 78 | + @pytest.mark.parametrize( |
| 79 | + ("matrix_elements", "expected_columns"), |
| 80 | + [ |
| 81 | + ([[1, 2], [3, 4]], 2), |
| 82 | + ([[1, 2, 3]], 3), |
| 83 | + ([[1], [2], [3]], 1), |
| 84 | + ], |
| 85 | + ids=["2x2", "1x3", "3x1"], |
| 86 | + ) |
| 87 | + def test_get_columns(self, matrix_elements, expected_columns): |
| 88 | + matrix = Matrix(matrix_elements) |
| 89 | + |
| 90 | + assert isinstance(matrix, Matrix) |
| 91 | + assert len(matrix.get_columns()) == expected_columns |
| 92 | + for column in matrix.get_columns(): |
| 93 | + assert isinstance(column, VGroup) |
| 94 | + |
| 95 | + @pytest.mark.parametrize( |
| 96 | + ("matrix_elements", "expected_rows"), |
| 97 | + [ |
| 98 | + ([[1, 2], [3, 4]], 2), |
| 99 | + ([[1, 2, 3]], 1), |
| 100 | + ([[1], [2], [3]], 3), |
| 101 | + ], |
| 102 | + ids=["2x2", "1x3", "3x1"], |
| 103 | + ) |
| 104 | + def test_get_rows(self, matrix_elements, expected_rows): |
| 105 | + matrix = Matrix(matrix_elements) |
| 106 | + |
| 107 | + assert isinstance(matrix, Matrix) |
| 108 | + assert len(matrix.get_rows()) == expected_rows |
| 109 | + for row in matrix.get_rows(): |
| 110 | + assert isinstance(row, VGroup) |
| 111 | + |
| 112 | + @pytest.mark.parametrize( |
| 113 | + ("matrix_elements", "expected_entries_tex_string", "expected_entries_count"), |
| 114 | + [ |
| 115 | + ([[1, 2], [3, 4]], ["1", "2", "3", "4"], 4), |
| 116 | + ([[1, 2, 3]], ["1", "2", "3"], 3), |
| 117 | + ], |
| 118 | + ids=["2x2", "1x3"], |
| 119 | + ) |
| 120 | + def test_get_entries( |
| 121 | + self, matrix_elements, expected_entries_tex_string, expected_entries_count |
| 122 | + ): |
| 123 | + matrix = Matrix(matrix_elements) |
| 124 | + entries = matrix.get_entries() |
| 125 | + |
| 126 | + assert isinstance(matrix, Matrix) |
| 127 | + assert len(entries) == expected_entries_count |
| 128 | + for index_entry, entry in enumerate(entries): |
| 129 | + assert isinstance(entry, MathTex) |
| 130 | + assert expected_entries_tex_string[index_entry] == entry.tex_string |
| 131 | + |
| 132 | + @pytest.mark.parametrize( |
| 133 | + ("matrix_elements", "row", "column", "expected_value_str"), |
| 134 | + [ |
| 135 | + ([[1, 2], [3, 4]], 0, 0, "1"), |
| 136 | + ([[1, 2], [3, 4]], 1, 1, "4"), |
| 137 | + ([[1, 2, 3]], 0, 2, "3"), |
| 138 | + ([[1], [2], [3]], 2, 0, "3"), |
| 139 | + ], |
| 140 | + ids=["2x2_00", "2x2_11", "1x3_02", "3x1_20"], |
| 141 | + ) |
| 142 | + def test_get_element(self, matrix_elements, row, column, expected_value_str): |
| 143 | + matrix = Matrix(matrix_elements) |
| 144 | + |
| 145 | + assert isinstance(matrix.get_columns()[column][row], MathTex) |
| 146 | + assert isinstance(matrix.get_rows()[row][column], MathTex) |
| 147 | + assert matrix.get_columns()[column][row].tex_string == expected_value_str |
| 148 | + assert matrix.get_rows()[row][column].tex_string == expected_value_str |
| 149 | + |
| 150 | + @pytest.mark.parametrize( |
| 151 | + ("matrix_elements", "row", "column", "expected_error"), |
| 152 | + [ |
| 153 | + ([[1, 2]], 1, 0, IndexError), |
| 154 | + ([[1, 2]], 0, 2, IndexError), |
| 155 | + ], |
| 156 | + ids=["row_out_of_bounds", "col_out_of_bounds"], |
| 157 | + ) |
| 158 | + def test_get_element_invalid(self, matrix_elements, row, column, expected_error): |
| 159 | + matrix = Matrix(matrix_elements) |
| 160 | + |
| 161 | + with pytest.raises(expected_error): |
| 162 | + matrix.get_columns()[column][row] |
| 163 | + |
| 164 | + with pytest.raises(expected_error): |
| 165 | + matrix.get_rows()[row][column] |
| 166 | + |
| 167 | + |
| 168 | +class TestDecimalMatrix: |
| 169 | + @pytest.mark.parametrize( |
| 170 | + ("matrix_elements", "num_decimal_places", "expected_elements"), |
| 171 | + [ |
| 172 | + ([[1.234, 5.678], [9.012, 3.456]], 2, [[1.234, 5.678], [9.012, 3.456]]), |
| 173 | + ([[1.0, 2.0], [3.0, 4.0]], 0, [[1, 2], [3, 4]]), |
| 174 | + ([[1, 2.3], [4.567, 7]], 1, [[1.0, 2.3], [4.567, 7.0]]), |
| 175 | + ], |
| 176 | + ids=[ |
| 177 | + "basic_2_decimal_points", |
| 178 | + "basic_0_decimal_points", |
| 179 | + "mixed_1_decimal_points", |
| 180 | + ], |
| 181 | + ) |
| 182 | + def test_decimal_matrix_init( |
| 183 | + self, matrix_elements, num_decimal_places, expected_elements |
| 184 | + ): |
| 185 | + matrix = DecimalMatrix( |
| 186 | + matrix_elements, |
| 187 | + element_to_mobject_config={"num_decimal_places": num_decimal_places}, |
| 188 | + ) |
| 189 | + |
| 190 | + assert isinstance(matrix, DecimalMatrix) |
| 191 | + for column_index, column in enumerate(matrix.get_columns()): |
| 192 | + for row_index, element in enumerate(column): |
| 193 | + assert element.number == expected_elements[row_index][column_index] |
| 194 | + assert element.num_decimal_places == num_decimal_places |
| 195 | + |
| 196 | + |
| 197 | +class TestIntegerMatrix: |
| 198 | + @pytest.mark.parametrize( |
| 199 | + ("matrix_elements", "expected_elements"), |
| 200 | + [ |
| 201 | + ([[1, 2], [3, 4]], [[1, 2], [3, 4]]), |
| 202 | + ([[1.2, 2.8], [3.5, 4]], [[1.2, 2.8], [3.5, 4]]), |
| 203 | + ], |
| 204 | + ids=["basic_int", "mixed_float_int"], |
| 205 | + ) |
| 206 | + def test_integer_matrix_init(self, matrix_elements, expected_elements): |
| 207 | + matrix = IntegerMatrix(matrix_elements) |
| 208 | + |
| 209 | + assert isinstance(matrix, IntegerMatrix) |
| 210 | + for row_index, row in enumerate(matrix.get_rows()): |
| 211 | + for column_index, element in enumerate(row): |
| 212 | + assert element.number == expected_elements[row_index][column_index] |
0 commit comments