Skip to content

Commit 8eff1e1

Browse files
committed
a cleaner version
1 parent 70d8681 commit 8eff1e1

File tree

2 files changed

+28
-30
lines changed

2 files changed

+28
-30
lines changed
Binary file not shown.

saddle-points/saddle_points.py

+28-30
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,30 @@
1717
and apply the following steps:
1818
1919
Finding "row's max.", a column vector equal to [9, 5, 7], and use if to compute a boolean
20-
matrix in which each cell (i,j) is `True` if this cell's value is the maximum value it
21-
its own row. This second operation is done using `match`:
22-
[ 9 ] match [ 9 8 7 ] = [ O . . ]
23-
[ 5 ] [ 5 3 2 ] [ O . . ]
24-
[ 7 ] [ 6 6 7 ] [ . . O ]
20+
matrix in which each cell (i,j) is `True` if this cell's value is the maximum value of
21+
its own row. This second operation is done using `flag_projected_matches`:
22+
[ 9 ] flag_projected_matches [ 9 8 7 ] = [ O . . ]
23+
[ 5 ] [ 5 3 2 ] [ O . . ] = A
24+
[ 7 ] [ 6 6 7 ] [ . . O ]
2525
26+
("O" match, "." otherwise)
2627
2728
The same is done for "column's min.", we first compute the row vector representing min.
28-
column values and then compute the boolean mask using `match`:
29-
[ 5 3 2 ] match [ 9 8 7 ] = [ . . . ]
30-
[ 5 3 2 ] [ O O O ]
31-
[ 6 6 7 ] [ . . . ]
29+
column values and then compute the boolean mask using `flag_projected_matches`:
30+
[ 5 3 2 ] flag_projected_matches [ 9 8 7 ] = [ . . . ]
31+
[ 5 3 2 ] [ O O O ] = B
32+
[ 6 6 7 ] [ . . . ]
3233
33-
34-
Once we have these two boolean matrix, we just need to find cells where both conditions
35-
are met to find Saddle points (using the logical `&` operator):
34+
Once we have these two boolean matrix, in order to find Saddle points we just need
35+
to find cells for which both conditions (A and B) are met (using the logical `&` operator):
3636
[ O . . ] & [ . . . ] = [ . . . ]
3737
[ O . . ] [ O O O ] [ O . . ]
3838
[ . . O ] [ . . . ] [ . . . ]
3939
4040
Finally, Saddle points coordinates are retrieved using the `where` method on this last matrix:
41-
[ . . . ]
42-
[ O . . ] where {(1, 0)}
43-
[ . . . ]
41+
[ . . . ]
42+
where [ O . . ] = {(1, 0)}
43+
[ . . . ]
4444
"""
4545
from enum import Enum
4646
from collections import namedtuple
@@ -52,9 +52,9 @@ def saddle_points(data):
5252
and less than or equal to every element in its column.
5353
"""
5454
matrix = Matrix2D(data)
55-
rows_maximums = matrix.axis_map_reduce(Matrix2D.Axes.ROW, max)
55+
rows_maximums = matrix.axis_reduce(Matrix2D.Axes.ROW, max)
5656
is_max_in_row = matrix.flag_projected_matches(rows_maximums)
57-
columns_minimums = matrix.axis_map_reduce(Matrix2D.Axes.COLUMN, min)
57+
columns_minimums = matrix.axis_reduce(Matrix2D.Axes.COLUMN, min)
5858
is_min_in_col = matrix.flag_projected_matches(columns_minimums)
5959
is_saddle = is_max_in_row & is_min_in_col
6060
saddle_indexes = set(is_saddle.where())
@@ -121,7 +121,7 @@ def T(self):
121121
self._transpose = Matrix2D(list(map(list, zip(*self))))
122122
return self._transpose
123123

124-
def axis_map_reduce(self, axis, function):
124+
def axis_reduce(self, axis, function):
125125
"""Reduce the matrix using `function`. The optional parameter `axis`
126126
allows to reduce only along the given axis.
127127
@@ -138,17 +138,17 @@ def axis_map_reduce(self, axis, function):
138138
def flag_projected_matches(self, vector):
139139
"""Each row/column of the input matrix is compared to the input `vector`:
140140
141-
Row matching:
141+
Rows matching:
142142
143-
[1 2 3] match [1 5 3] = [O . O]
144-
[4 5 3] [. O O]
143+
[1 2 3] flag_projected_matches [1 5 3] = [O . O]
144+
[4 5 3] [. O O]
145145
146-
Column matching:
146+
Columns matching:
147147
148-
[1 2 3] match [1] = [O . .]
149-
[4 5 6] [5] [. O .]
148+
[1 2 3] flag_projected_matches [1] = [O . .]
149+
[4 5 6] [5] [. O .]
150150
151-
O shows matched items, . shows unmatched items (either it matched the vector/scalar or not).
151+
O shows matched items, . shows unmatched items (either it matched the vector or not).
152152
The function returns a matrix where matched cells contain `True` while unmatched ones contain `False`.
153153
"""
154154
return self.map_cells(lambda coordinates, cell: cell == vector[vector.projected_coordinates(coordinates)])
@@ -165,8 +165,8 @@ def all(self, condition=lambda x: x):
165165
all(condition(cell) for _, cell in self.enumerate_cells())
166166

167167
def map_cells(self, function):
168-
"""Returns a matrix in which `function` has been applied to all element.
169-
`function` will be called with three arguments (on each cell):
168+
"""Returns a matrix in which `function` has been applied to all elements.
169+
`function` will be called with the three following arguments (on each cell):
170170
- the `cell` value,
171171
- the row index and the
172172
- the column index.
@@ -186,9 +186,7 @@ def __eq__(self, other):
186186
return False
187187
return self.map_cells(lambda coordinates, cell: cell == other[coordinates]).all()
188188

189-
def _binary_operation(self, other, operation=None):
190-
if not operation:
191-
raise ValueError("Please provide a function.")
189+
def _binary_operation(self, other, operation):
192190
return Matrix2D([[operation(a, b) for (a, b) in zip(*rows)] for rows in zip(self, other)])
193191

194192
def __and__(self, other):

0 commit comments

Comments
 (0)