Skip to content

Commit 16faeb4

Browse files
Add fast pow video
1 parent a3e5b34 commit 16faeb4

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ James and his team are available for consulting, contracting, code reviews, and
1212

1313
| N | Code | Video |
1414
| -- | --- |--- |
15+
| 115 | [src](videos/115_fast_pow) | [Fast pow](https://youtu.be/GrNJE6ogyQU) |
1516
| 114 | [src](videos/114_copy_or_no_copy) | [Python Iterators! COPY or NO COPY?](https://youtu.be/hVFKy9Gw95c) |
1617
| 113 | [src](videos/113_getting_rid_of_recursion) | [Getting around the recursion limit](https://youtu.be/1dUpHL5Yg8E) |
1718
| 112 | [src](videos/112_python_name_mangling) | [Every Python dev falls for this (name mangling)](https://youtu.be/0hrEaA3N3lk) |

videos/115_fast_pow/fast_pow.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from collections.abc import Callable
2+
from typing import TypeVar
3+
import operator
4+
5+
import matplotlib.pyplot as plt
6+
7+
import numpy as np
8+
9+
T = TypeVar('T')
10+
BinOp = Callable[[T, T], T]
11+
12+
"""
13+
Idea:
14+
n == 2 * n // 2 + n % 2
15+
Therefore:
16+
x ** n == x ** (n % 2) * (x ** (n//2)) ** 2
17+
E.g.
18+
x ** 31 == x * (x ** 15) ** 2
19+
"""
20+
21+
22+
def slow_pow(x, n: int):
23+
res = 1
24+
for _ in range(n):
25+
res *= x
26+
return res
27+
28+
29+
def fast_pow(x, n: int):
30+
if n == 0:
31+
return 1
32+
half_n, rem = divmod(n, 2) # n // 2, n % 2
33+
res = fast_pow(x, half_n)
34+
res = res * res
35+
return x * res if rem else res
36+
37+
38+
def fast_pow_monoid_strategy(mul: BinOp[T], identity: T, x: T, n: int) -> T:
39+
if n < 0:
40+
raise ValueError(f"n must be >= 0, but got {n}")
41+
if n == 0:
42+
return identity
43+
half_n, rem = divmod(n, 2)
44+
res = fast_pow_monoid_strategy(mul, identity, x, half_n)
45+
res = mul(res, res)
46+
return mul(x, res) if rem else res
47+
48+
49+
def fast_pow_semigroup_strategy(mul: BinOp[T], x: T, n: int) -> T:
50+
if n < 1:
51+
raise ValueError(f"n must be > 0, but got {n}")
52+
if n == 1:
53+
return x
54+
half_n, rem = divmod(n, 2)
55+
res = fast_pow_semigroup_strategy(mul, x, half_n)
56+
res = mul(res, res)
57+
return mul(x, res) if rem else res
58+
59+
60+
def fast_pow_int_examples():
61+
assert slow_pow(1, 100) == 1
62+
assert slow_pow(5, 3) == 125
63+
assert slow_pow(2, 11) == 2048
64+
65+
assert fast_pow(1, 100) == 1
66+
assert fast_pow(5, 3) == 125
67+
assert fast_pow(2, 11) == 2048
68+
69+
70+
def fast_pow_monoid_examples():
71+
assert fast_pow_monoid_strategy(operator.mul, 1, 1, 100) == 1
72+
assert fast_pow_monoid_strategy(operator.mul, 1, 5, 3) == 125
73+
assert fast_pow_monoid_strategy(operator.mul, 1, 2, 11) == 2048
74+
75+
assert fast_pow_monoid_strategy(operator.add, 0, 1, 100) == 100
76+
assert fast_pow_monoid_strategy(operator.add, 0, 5, 3) == 15
77+
assert fast_pow_monoid_strategy(operator.add, 0, 2, 11) == 22
78+
79+
assert fast_pow_monoid_strategy(operator.concat, "", "abc", 5) == "abcabcabcabcabc"
80+
81+
82+
def fast_pow_semigroup_examples():
83+
assert fast_pow_semigroup_strategy(operator.mul, 1, 100) == 1
84+
assert fast_pow_semigroup_strategy(operator.mul, 5, 3) == 125
85+
assert fast_pow_semigroup_strategy(operator.mul, 2, 11) == 2048
86+
87+
assert fast_pow_semigroup_strategy(operator.add, 1, 100) == 100
88+
assert fast_pow_semigroup_strategy(operator.add, 5, 3) == 15
89+
assert fast_pow_semigroup_strategy(operator.add, 2, 11) == 22
90+
91+
assert fast_pow_semigroup_strategy(operator.concat, "abc", 5) == "abcabcabcabcabc"
92+
93+
n_samples = 1000
94+
x = np.linspace(0, 1, n_samples)
95+
uniform_dist = np.ones_like(x) / n_samples
96+
assert np.isclose(np.sum(uniform_dist), 1.0), np.sum(uniform_dist)
97+
98+
for n in [1, 2, 3, 4, 5]:
99+
smoothed_dist = fast_pow_semigroup_strategy(np.convolve, uniform_dist, n)
100+
assert np.isclose(np.sum(smoothed_dist), 1.0), np.sum(smoothed_dist)
101+
xn = np.linspace(x[0] * n, x[-1] * n, len(smoothed_dist))
102+
normalizing_constant = len(smoothed_dist) / (xn[-1] - xn[0])
103+
plt.plot(xn, smoothed_dist * normalizing_constant)
104+
105+
plt.show()
106+
107+
108+
def main():
109+
fast_pow_int_examples()
110+
fast_pow_monoid_examples()
111+
fast_pow_semigroup_examples()
112+
113+
114+
if __name__ == '__main__':
115+
main()

0 commit comments

Comments
 (0)