Sorry, we've misplaced that URL or it's pointing to something that doesn't exist.
Head back Home to try finding it again, or search for it on the Archives page.
diff --git a/404.html b/404.html new file mode 100644 index 0000000..51eeeb8 --- /dev/null +++ b/404.html @@ -0,0 +1 @@ +
Sorry, we've misplaced that URL or it's pointing to something that doesn't exist.
Head back Home to try finding it again, or search for it on the Archives page.
This site is a collection of problem solving techniques, with related problems and implementations in Python + CPP. It also houses a printable binder for use in contests (COMING SOON).
If you’d like to contribute, please let me know (on Discord, jackson#5358) and I can add you to the project.
Any and all recommendations for particular problems or algorithms are also appreciated.
a.height()*j||n-b.offset().top-b.outerHeight()-(a.height()*j)>0){return true}e("#disqus_thread").removeAttr("id");b.attr("id","disqus_thread").data("disqusLoaderStatus","loaded");if(f=="loaded"){DISQUS.reset({reload:true,config:i})}else{g.disqus_config=i;if(f=="unloaded"){f="loading";e.ajax({url:c,async:true,cache:true,dataType:"script",success:function(){f="loaded"}})}}};a.on("scroll resize",k(m,l));e.disqusLoader=function(o,n){n=e.extend({},{laziness:1,throttle:250,scriptUrl:false,disqusConfig:false},n);j=n.laziness+1;m=n.throttle;i=n.disqusConfig;c=c===false?n.scriptUrl:c;b=(typeof o=="string"?e(o):o).eq(0);b.data("disqusLoaderStatus","unloaded");l()}})(jQuery,window,document);
\ No newline at end of file
diff --git a/categories/contests/index.html b/categories/contests/index.html
new file mode 100644
index 0000000..739650c
--- /dev/null
+++ b/categories/contests/index.html
@@ -0,0 +1 @@
+ Contests | Monash Code Binder Contests 1
- Challenge Problems - 2021 Sem 2, Contest 1 Aug 23, 2021
diff --git a/categories/data-structures/index.html b/categories/data-structures/index.html
new file mode 100644
index 0000000..8544254
--- /dev/null
+++ b/categories/data-structures/index.html
@@ -0,0 +1 @@
+ Data Structures | Monash Code Binder Data Structures 3
- Union Find / DSU Dec 15, 2021
- Least Common Ancestor (LCA) Apr 20, 2021
- Dynamic Programming Mar 26, 2021
diff --git a/categories/index.html b/categories/index.html
new file mode 100644
index 0000000..ecec670
--- /dev/null
+++ b/categories/index.html
@@ -0,0 +1 @@
+ Categories | Monash Code Binder Categories
diff --git a/categories/math/index.html b/categories/math/index.html
new file mode 100644
index 0000000..5b1c67f
--- /dev/null
+++ b/categories/math/index.html
@@ -0,0 +1 @@
+ Math | Monash Code Binder Math 3
- DataStructureLess Competition 2023 Editorial Dec 28, 2023
- Primes and Factorization Techniques Apr 5, 2021
- Modular Arithmetic Mar 29, 2021
diff --git a/categories/trees/index.html b/categories/trees/index.html
new file mode 100644
index 0000000..7b824e4
--- /dev/null
+++ b/categories/trees/index.html
@@ -0,0 +1 @@
+ Trees | Monash Code Binder Trees 1
- Least Common Ancestor (LCA) Apr 20, 2021
diff --git a/feed.xml b/feed.xml
new file mode 100644
index 0000000..6bad25a
--- /dev/null
+++ b/feed.xml
@@ -0,0 +1 @@
+ https://monashaps.github.io// Monash Code Binder A collection of algorithms, explanations and training problems 2023-12-28T23:03:47+11:00 Monash Programming Team https://monashaps.github.io// Jekyll © 2023 Monash Programming Team /assets/img/favicons/favicon.ico /assets/img/favicons/favicon-96x96.png DataStructureLess Competition 2023 Editorial 2023-12-28T18:00:00+11:00 2023-12-28T23:02:40+11:00 https://monashaps.github.io//posts/dsless-editorial/ Monash Programming Team Since the intention of the DataStructureLess Competition was to showcase some interesting/unique solve techniques, I thought it would be good to provide some editorial for all of the problems so everyone can see some of the cool stuff on offer. Each problem has been given a few hints, so you can hopefully have a stab at the solution even if you got stuck in contest, but a solution is also prov... Union Find / DSU 2021-12-15T12:00:00+11:00 2021-12-27T13:53:13+11:00 https://monashaps.github.io//posts/uf/ Monash Programming Team Where is this useful? In many problems, translating into a graph structure can prove helpful, as we can describe our problem in very abstract terms. Once you’ve translated into this graph structure, often you might want to know whether two vertices are connected via a path, and if this is not the case, what two separate components they come from. Union Find allows us to not only answer this q... Challenge Problems - 2021 Sem 2, Contest 1 2021-08-23T11:00:00+10:00 2021-08-23T11:00:00+10:00 https://monashaps.github.io//posts/problems-21-s2-c1/ Monash Programming Team Sports Loans Statement Andrew is head of the sports club, and manages the inventory. Part of Andrew’s job is loaning footballs to people, and collecting those footballs once they have been used. At the start of the day, Andrew has \(r\) footballs in stock, and knows that \(p+q\) people will approach him over the course of the day. \(p\) people will request a football, while \(q\) people will... 
Least Common Ancestor (LCA) 2021-04-20T22:00:00+10:00 2021-12-13T11:01:10+11:00 https://monashaps.github.io//posts/lca/ Monash Programming Team Where is this useful? The Least Common Ancestor (LCA) data structure is useful wherever you have a directed graph where every vertex has out-degree \(\leq 1\). In more common terms, each vertex has a unique determined ‘parent’, or it is a root node, with no parent. The most common (and almost always only) example being a rooted tree. On these particular graphs, the LCA gives us a fast way to ... Primes and Factorization Techniques 2021-04-05T11:00:00+10:00 2021-04-05T21:02:26+10:00 https://monashaps.github.io//posts/factorization/ Monash Programming Team Why? Many number theoretic problems in competitive programming require analysing the factors or prime factors of a number. Here I’ll list a few techniques for finding these factors, and some techniques / properties involving the factors / prime factors of a number. Preliminaries First off, lets define our basic terms, then we can get into the interesting stuff. Definitions For a positive i...
diff --git a/index.html b/index.html
new file mode 100644
index 0000000..ef04138
--- /dev/null
+++ b/index.html
@@ -0,0 +1 @@
+ Monash Code Binder DataStructureLess Competition 2023 Editorial
Since the intention of the DataStructureLess Competition was to showcase some interesting/unique solve techniques, I thought it would be good to provide some editorial for all of the problems so ev...
Union Find / DSU
Where is this useful? In many problems, translating into a graph structure can prove helpful, as we can describe our problem in very abstract terms. Once you’ve translated into this graph structu...
Challenge Problems - 2021 Sem 2, Contest 1
Sports Loans Statement Andrew is head of the sports club, and manages the inventory. Part of Andrew’s job is loaning footballs to people, and collecting those footballs once they have been used. ...
Least Common Ancestor (LCA)
Where is this useful? The Least Common Ancestor (LCA) data structure is useful wherever you have a directed graph where every vertex has out-degree \(\leq 1\). In more common terms, each vertex ha...
Primes and Factorization Techniques
Why? Many number theoretic problems in competitive programming require analysing the factors or prime factors of a number. Here I’ll list a few techniques for finding these factors, and some techn...
Modular Arithmetic
What is it? Modular Arithmetic encompasses all sorts of theorems and optimizations surrounding the % operator in C and Python. As you’ll see in the related problems, modulo arithmetic is often ti...
Dynamic Programming
Why? Dynamic Programming (DP) is one of the most powerful tools you’ll come across in competitive programming. It normally turns up in about a third of all problems in a contest, in some form or a...
diff --git a/norobots/index.html b/norobots/index.html
new file mode 100644
index 0000000..c31a99c
--- /dev/null
+++ b/norobots/index.html
@@ -0,0 +1,11 @@
+
+
+
+ Redirecting…
+
+
+
+
+ Redirecting…
+ Click here if you are not redirected.
+
diff --git a/posts/dp/index.html b/posts/dp/index.html
new file mode 100644
index 0000000..3ec59c2
--- /dev/null
+++ b/posts/dp/index.html
@@ -0,0 +1,153 @@
+ Dynamic Programming | Monash Code Binder Dynamic Programming
Why?
Dynamic Programming (DP) is one of the most powerful tools you’ll come across in competitive programming. It normally turns up in about a third of all problems in a contest, in some form or another.
What is it?
Dynamic Programming is a general tactic centred around storing previous calculations so that you don’t need to recompute the outcome of certain functions. While this might seem like common sense, these precomputed calculations can hide themselves quite well.
Example
One of the most popular applications of DP in contests is in solving recurrence relations in a short amount of time. Consider the following problem, Cut Ribbon:
Polycarpus has a ribbon, its length is n. He wants to cut the ribbon in a way that fulfils the following two conditions:
- After the cutting each ribbon piece should have length a, b or c.
- After the cutting the number of ribbon pieces should be maximum.
Help Polycarpus and find the number of ribbon pieces after the required cutting.
Input
The first line contains four space-separated integers n, a, b and c (1 ≤ n, a, b, c ≤ 4000) — the length of the original ribbon and the acceptable lengths of the ribbon pieces after the cutting, correspondingly. The numbers a, b and c can coincide.
Output
Print a single number — the maximum possible number of ribbon pieces. It is guaranteed that at least one correct ribbon cutting exists.
Solution
Let’s try to create a function max_cuts(x)
, which tells us the maximum number of cuts for a ribbon of size x
. We want to compute max_cuts(n)
.
Since every possible cutting of the ribbon must start with a cut of size a
, b
, or c
, max_cuts(x)
must be either max_cuts(x-a)+1
, max_cuts(x-b)+1
or max_cuts(x-c)+1
(Assuming base case max_cuts(0) == 0
).
Therefore, we can define max_cuts
recursively in the following way:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+
def max_cuts(x):
+ if x == 0:
+ return 0
+ best = 0
+ for cut in (a, b, c):
+ if x - cut >= 0:
+ best = max(best, max_cuts(x-cut) + 1)
+ if best == 0:
+ # Not possible to cut this length.
+ best = -100000
+ return best
+
+# We then just call
+max_cuts(n)
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+
int max_cuts(int x) {
+ if (x == 0) return 0;
+ int best = 0;
+ if (x - a >= 0) best = max(best, max_cuts(x - a) + 1);
+ if (x - b >= 0) best = max(best, max_cuts(x - b) + 1);
+ if (x - c >= 0) best = max(best, max_cuts(x - c) + 1);
+ // Not possible to cut this length.
+ if (best == 0) best = -100000;
+ return best;
+}
+
However, if you submit this as is, you will almost certainly get TLE, despite n <= 4000
. Why is this?
Let's inspect the call tree for max_cuts
with n=5
, and a, b, c = 1, 2, 3
:
As you can see, despite only taking 6 different values, max_cuts
is being called a bunch of times, which is unnecessary! Instead, we can save each value with DP!
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+
# DP[x] = max_cuts(x) if computed, or -1.
+DP = [-1] * (n+1)
+
+def max_cuts(x):
+ if DP[x] != -1:
+ # If already computed, just return the value!
+ return DP[x]
+ if x == 0:
+ DP[x] = 0
+ return DP[x]
+ best = 0
+ for cut in (a, b, c):
+ if x - cut >= 0:
+ best = max(best, max_cuts(x-cut) + 1)
+ if best == 0:
+ # Not possible
+ DP[x] = -100000
+ return DP[x]
+ # Set the DP value, so that future calls to max_cuts(x) just use DP[x].
+ DP[x] = best
+ return DP[x]
+
+max_cuts(n)
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+
// DP[x] = max_cuts(x) if computed, or -1.
+int DP[4002];
+
+int max_cuts(int x) {
+ // If already computed, just return the value!
+ if (DP[x] != -1) return DP[x];
+ if (x == 0) {
+ DP[x] = 0;
+ return DP[x];
+ }
+ int best = 0;
+ if (x - a >= 0) best = max(best, max_cuts(x-a) + 1);
+ if (x - b >= 0) best = max(best, max_cuts(x-b) + 1);
+ if (x - c >= 0) best = max(best, max_cuts(x-c) + 1);
+ if (best == 0) {
+ // Not possible.
+ DP[x] = -100000;
+ return DP[x];
+ }
+ // Set the DP value, so that future calls to max_cuts(x) just use DP[x].
+ DP[x] = best;
+ return DP[x];
+}
+
+int main() {
+ // Initialise DP
+ for (int i=0; i<4002; i++) DP[i] = -1;
+ cout << max_cuts(n) << endl;
+}
+
This won’t TLE, because we only need to compute 4001 values, max_cuts(x)
for any x
from 0 to n
.
Related Problems
This post is licensed under GNU GPL V3 by the author. Comments powered by Disqus.
diff --git a/posts/dsless-editorial/index.html b/posts/dsless-editorial/index.html
new file mode 100644
index 0000000..6a5389d
--- /dev/null
+++ b/posts/dsless-editorial/index.html
@@ -0,0 +1,2653 @@
+ DataStructureLess Competition 2023 Editorial | Monash Code Binder DataStructureLess Competition 2023 Editorial
Since the intention of the DataStructureLess Competition was to showcase some interesting/unique solve techniques, I thought it would be good to provide some editorial for all of the problems so everyone can see some of the cool stuff on offer.
Each problem has been given a few hints, so you can hopefully have a stab at the solution even if you got stuck in contest, but a solution is also provided.
Binary 1
Hint 1
Simulating won’t be enough, because of the size of $i$. We need to somehow skip most of the previous values.
Hint 2
Note that the lengths of the binary numbers increase as we move along the sequence, in fact there are $2^k$ binary numbers of length $k+1$
Solution
Assuming you’ve read the previous two hints, we want to skip ‘blocks’ of binary numbers of equal length. Since these blocks at least double in size each time we can get rid of an exponential amount of numbers before our index. We can continue subtracting these larger and larger blocks until our index would be exceeded by the next block: a jump of size $(k+1) * 2^k$, which tells us that the value we are trying to find is within a binary number of length $k+1$.
Now we know that after dealing with the first $p$ digits (which contain all binary strings of length less than $k+1$), we are left to find the $(i-p)^{\text{th}}$ value in the sequence of binary strings of length $k+1$. But since all binary strings are the same length now, we know we’re actually looking at the $\left(\frac{i-p}{k+1}\right)^{\text{th}}$ binary string in that sequence! From here we can just do some indexing to get what we need.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+
index = int(input())
+
+k = 0
+
+# While our index is not in the next block of binary strings of length k+1
+while index > (1 << k) * (k+1):
+ # Subtract our index to "offset" removing those binary strings
+ index -= (1 << k) * (k+1)
+ k += 1
+
+bit_length = k+1
+
+# 0 index, rather than 1-index.
+index -= 1
+# The jth binary string of length k+1 is 2^k + j (j is 0-indexed)
+skip_num = index // bit_length
+actual_num = (1 << k) + skip_num
+
+# Now the remaining index is simply the index of our singular binary number
+index = index % bit_length
+
+print(bin(actual_num)[2:][index])
+
Complexity $\mathcal{O}(\log_2(n))$
Binary 2
Hint 1
If you’ve solved Binary 1, we need to make a similar revelation about jumps.
Hint 2
Notice that in blocks of binary strings of equal size, the first bit is always 1, and every other bit is equal parts 0 and 1.
Solution
As the hints note, since we cycle through every binary number in a block, the numbers 0 and 1 appear the same amount, except for the first bit of every number, which is always 1.
Therefore for a string of $2^k$ binary numbers of size $k+1$, they contain $k*2^{k-1} + 2^k$ 1s.
Once we’ve dealt with everything except our block, rather than iterating through the final block, we can make use of this fact for “subblocks”.
For example, if our final number starts with “11”, it means that all binary strings of length $k+1$ starting with “10” are also included, so the last $k-1$ bits in all these numbers have an equal amount of 0s and 1s. If instead our final number starts with “10”, then we can simply recurse down. This is a bit hard to express in code but hopefully the logic above is clear.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+
import sys
+
+index = int(input())
+
+total_ones = 0
+k = 0
+bit_length = 1
+
+while index > (1 << k) * (k+1):
+ index -= (1 << k) * (k+1)
+ if k == 0:
+ total_ones += 1
+ else:
+ # Same formula as $k*2^{k-1} + 2^k$
+ total_ones += (1 << (k - 1)) * (k + 2)
+ bit_length += 1
+
+# We've counted all 1s in the prior blocks.
+
+# 0 index.
+index -= 1
+# Our number has this bit_length.
+skip_num = index // bit_length
+actual_num = (1 << (bit_length-1)) + skip_num
+
+def rec(prev_ones, min_val, max_val, power):
+ # Recursive function to count all 1s in our current block.
+ # prev_ones: 1s to the left of our current bit (IE, if we've got to our binary number starting with `1101`, then there are 3 previous 1s, which will always be 1s for future binary strings)
+ # min_val: The minimum value of the search block
+ # max_val: The maximum value of the search block
+ # ^ These two will squish together by powers of 2
+ # power: The power of 2 we are searching for next (decreases by 1 each time)
+ global total_ones
+ if power < 0:
+ return
+ print(f"{min_val} {max_val} jump {1 << power} ones {prev_ones}", file=sys.stderr)
+ # min_val is always a power of 2
+ # max_val is either a power of 2 or smaller (Since it starts as the actual number we are looking for).
+ if min_val + (1 << power) <= max_val:
+ # Our number has a `1` in the nth bit
+ # We can skip to the right half and count all the 1s in the left!
+ # First, add all the static 1s.
+ total_ones += (1 << power) * prev_ones
+ print(f"{(1 << power) * prev_ones} ones added from previous indicies", file=sys.stderr)
+ if power > 0:
+ # And also add the ones which occur with 50% frequency.
+ total_ones += (1 << (power-1)) * power
+ print(f"{(1 << (power-1)) * power} extra ones in the left half added", file=sys.stderr)
+ # recurse
+ rec(prev_ones+1, min_val + (1 << power), max_val, power-1)
+ else:
+ # Our number has a `0` in the nth bit
+ # We are in the left half
+ if power > 0:
+ rec(prev_ones, min_val, max_val, power-1)
+
+print(f"Num lives in {actual_num}", file=sys.stderr)
+rec(1, 1 << (bit_length-1), actual_num, bit_length-2)
+
+index = index % bit_length
+# rec doesn't count the final number.
+total_ones += bin(actual_num)[2:][:index+1].count("1")
+
+print(total_ones)
+
Complexity $\mathcal{O}(\log_2(n))$
Binary 3
Hint 1
Note that for even jump sizes, the answer is the same if we just divide the jump size by 2. So you can assume the jump size is odd.
Hint 2
While we don’t quite have the same nice rule about equal numbers of 1s and 0s, there is still some structure in our bits. For example, note (assuming odd jump size) that the least significant bit always toggles between 0 and 1. What happens to the second bit, the third?
Solution
The revelation here is that if we look at the first $2^k$ numbers in the sequence, the $k$ least significant bits actually do have an equal number of 0s and 1s! There are a few nice proofs of this, and I’ll leave it as a task for the reader to attempt (Hint: Note that if the jump size is odd, the jump size and $2^k$ are coprime).
So, we can continue some similar logic here to get rid of the first $k$ bits to deal with (and since we are dealing with a power of 2 as input, we don’t have to worry about our ‘current’ block).
Now all we need to do is worry about the extra bits we missed. Note that the jump size determines how many extra bits there are. In general, we should have $\log_2(j)$ extra bits to deal with. But this means that there’s at most $\approx j$ unique values for these extra bits, so we can simply find all of these values and add them up separately, by recursing in blocks of size $2^p$.
While you can solve this recursively, using the fact that the number of values divisible by $j$ in the range $(a, b]$ is $\lfloor \frac{b}{j} \rfloor - \lfloor \frac{a}{j} \rfloor$, as team de
noted, you can also just use this formula between $[2^k\times a, 2^k\times (a+1))$ for every $a$ from 0 to $j$.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+
import sys
+from math import log2, floor, ceil
+
+repeats, jump = list(map(int, input().split()))
+
+# Make jump odd.
+while jump % 2 == 0:
+ jump //= 2
+
+total_ones = 0
+# First, determine how many of the first k bits can be handled separately.
+handled_bit_length = floor(log2(repeats))
+
+# Handle the first handled_bit_length bits.
+total_ones += (1 << (handled_bit_length - 1)) * handled_bit_length if handled_bit_length >= 1 else 0
+
+print(f"Handled {total_ones} ones in the known block.", file=sys.stderr)
+
+# Now we need to count the occurrence of all bits after this in the sequence.
+def rec(minval, maxval, cur_bit):
+ # rec checks for all 1s in cur_bit between minval and maxval.
+ global total_ones
+
+ # If cur_bit gets too small, we'll start double counting the bits we handled separately.
+ if cur_bit <= handled_bit_length - 1:
+ return
+ midway = minval + (1 << cur_bit)
+ if midway <= maxval:
+ # We have some space in the '1' of this cur_bit.
+ # Count how many numbers are within that are divisible by `jump`.
+ print(f"{maxval // jump - (midway-1) // jump} values in [{midway}, {maxval}], and all of these have a 1 in the {cur_bit}th bit.", file=sys.stderr)
+ total_ones += maxval // jump - (midway-1) // jump
+
+ # Recurse on the right branch
+ rec(midway, maxval, cur_bit-1)
+ # Recurse on the left branch
+ rec(minval, min(midway-1, maxval), cur_bit-1)
+
+rec(0, repeats*jump, ceil(log2(repeats * jump)))
+
+print(total_ones)
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+
i,j = map(int, input().split())
+
+while (j%2 == 0):
+ j //= 2
+
+
+k = 0
+iCopy = i
+while iCopy > 1:
+ k += 1
+ iCopy//=2
+
+s = [0]
+for x in range(1,j+1):
+ s.append(x%2 + s[x//2])
+
+# s[x] = # 1 bits in the binary representation of x.
+
+tot = 0
+for x in range(j + 1):
+ # count the occurences of s[x]*2^k up until s[x+1]*2^k.
+ tot += (min((i*(x+1)-1), i*j)//j - (i*x-1)//j)*s[x]
+
+# Add the number of 1s in the least k significant bits
+tot += k*i//2
+
+print(tot)
+
Complexity $\mathcal{O}(\log_2(i) + j)$
Coins 1
This is a classic problem
Hint 1
The bounds imply a logarithmic solution. What’s the base of the logarithm?
Hint 2
Something akin to binary search would be good, although the binary search is optimal for a usual search because there are 2 equally likely outcomes for a query (value is less than or greater than the tester, equality means we stop immediately)
How many possible outcomes can the seesaw have? How can we use this to design a faster search?
Solution
The solution is to recognise that we want our query to break the solution space into three parts, depending on the seesaw result. We can do this by weighing one third of the remaining coins against another third. Then:
- If the left side is heavier, we need only recurse on that third of the coins
- If the right side is heavier, we need only recurse on that third of the coins
- If the left and right side are equal, then the fake coin must not have been weighed, so we recurse on the remaining third of the coins.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+
def solve(coins):
+ if len(coins) == 1:
+ return coins[0]
+ if len(coins) == 2:
+ print(f"TEST {coins[0]} | {coins[1]}")
+ res = input()
+ if res == "LEFT":
+ return coins[0]
+ elif res == "RIGHT":
+ return coins[1]
+ else:
+ raise ValueError()
+ amount = len(coins) // 3
+ coin_left = coins[:amount]
+ coin_right = coins[amount:2*amount]
+ coin_extra = coins[2*amount:]
+ print(f"TEST {' '.join(map(str, coin_left))} | {' '.join(map(str, coin_right))}")
+ res = input()
+ if res == "LEFT":
+ return solve(coin_left)
+ elif res == "RIGHT":
+ return solve(coin_right)
+ elif res == "EQUAL":
+ return solve(coin_extra)
+
+n = int(input())
+
+coins = list(range(1, n+1))
+
+print("ANS", solve(coins))
+
Complexity: $\mathcal{O}(\log_3(n))$
Coins 2
A similar idea for a problem, with some added intricacy - How do I recurse quickly to resolve where the 2 coins are?
Hint 1
Obviously since the setup is the same if we can place the 2 coins in separate piles, then we can simply apply the previous solution to solve within time.
So our solution needs to either:
- Recurse into a smaller problem with 2 fake coins
- Separate into two separate problems with a single fake coin each
Hint 2
We can’t quite immediately split into 3 evenly distributed problems because each seesaw option can feasibly be caused by two different configurations (For example, the left pile being bigger could be 2 in left, 0 elsewhere, or 1 in left and 1 unweighed).
Can we either:
- Change what we’re weighing so that this isn’t the case? or
- Provide additional weighing steps to disambiguate.
Solution
Following on from Hint 2, let’s follow these two options to two different solutions.
Option 1: Change what we query
Note that the fact that we have a third of the coins unweighed is the main cause of ambiguity. If there were a way to limit the size of the unweighed portion then our problems would mostly go away. So rather than splitting into thirds, let's do the original naive thing for coins 1, splitting in half, and only at most 1 coin will miss out from weighing. Then:
- If the seesaw goes LEFT, then all fake coins are in the left pile (or the additional unweighed)
- If the seesaw goes RIGHT, then all fake coins are in the right pile (or the additional unweighed)
- If the seesaw is EQUAL, then the additional coin cannot be fake. There must be a fake coin in each of the two weighed piles
This solution will have the maximum of $\log_2(n)$ and $2\log_3(n)$ queries ($2\log_3(n)$).
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+
def solve_1(coins):
+ # log_3(n)
+ if len(coins) == 1:
+ return coins[0]
+ if len(coins) == 2:
+ print(f"TEST {coins[0]} | {coins[1]}")
+ res = input()
+ if res == "LEFT":
+ return coins[0]
+ elif res == "RIGHT":
+ return coins[1]
+ else:
+ raise ValueError()
+ amount = len(coins) // 3
+ coin_left = coins[:amount]
+ coin_right = coins[amount:2*amount]
+ coin_extra = coins[2*amount:]
+ print(f"TEST {' '.join(map(str, coin_left))} | {' '.join(map(str, coin_right))}")
+ res = input()
+ if res == "LEFT":
+ return solve_1(coin_left)
+ elif res == "RIGHT":
+ return solve_1(coin_right)
+ elif res == "EQUAL":
+ return solve_1(coin_extra)
+
+def solve_2(coins):
+ # 2*log_3(n)
+ if len(coins) == 2:
+ return coins
+ amount = len(coins) // 2
+ coin_left = coins[:amount]
+ coin_right = coins[amount:2*amount]
+ coin_extra = coins[2*amount:]
+ print(f"TEST {' '.join(map(str, coin_left))} | {' '.join(map(str, coin_right))}")
+ res = input()
+ if res == "LEFT":
+ return solve_2(coin_left + coin_extra)
+ elif res == "RIGHT":
+ return solve_2(coin_right + coin_extra)
+ elif res == "EQUAL":
+ return solve_1(coin_left), solve_1(coin_right)
+
+
+n = int(input())
+
+coins = list(range(1, n+1))
+
+print("ANS", *solve_2(coins))
+
+
Option 2: Add a clarifying additional query.
This solution is more complicated, where we instead add an additional query to resolve the initial result.
Let’s call the state x-y-z if there are x fake coins in the left pile, y in the right, and z in the remaining unweighed coins
- If the original query is LEFT, then this is either 2-0-0 or 1-0-1.
- We can add an additional query comparing one half of extra to the other half of extra
- If the second query is left or right, it is 1-0-1 and we can recurse
- If the second query is equal, then it is 2-0-0 (or the extra coins were odd and the remaining unweighed is fake), and we can recurse
- Same rule applies for RIGHT, either 0-2-0 or 0-1-1.
- If the original query is EQUAL, then this is either 1-1-0 or 0-0-2.
- We can resolve this by weighing all of the unweighed coins against a combination of left and right coins.
- If the second query says the LEFT, then the unweighed coins are heavier and we recurse on the unweighed coins
- If the second query says the RIGHT, then the left/right pile coins have a fake coin each
- If the second query says EQUAL, then the left/right pile coins we haven’t chosen are the ones that must have a fake coin each
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+
def solve_1(coins):
    """Locate the single fake coin among `coins` in ~log_3(n) weighings.

    Each round splits the candidates into three near-equal piles, weighs
    the first two against each other, and recurses into whichever pile the
    judge's answer implicates.
    """
    if len(coins) == 1:
        return coins[0]
    if len(coins) == 2:
        # Direct comparison settles two candidates; EQUAL would contradict
        # the premise that exactly one of them is fake.
        print(f"TEST {coins[0]} | {coins[1]}")
        verdict = input()
        if verdict == "LEFT":
            return coins[0]
        if verdict == "RIGHT":
            return coins[1]
        raise ValueError()
    third = len(coins) // 3
    left = coins[:third]
    right = coins[third:2 * third]
    rest = coins[2 * third:]
    print(f"TEST {' '.join(map(str, left))} | {' '.join(map(str, right))}")
    verdict = input()
    if verdict == "LEFT":
        return solve_1(left)
    if verdict == "RIGHT":
        return solve_1(right)
    if verdict == "EQUAL":
        return solve_1(rest)
+
def solve_2(coins):
    """Find both fake (heavy) coins among `coins` in ~2*log_3(n) weighings.

    One three-way split plus a clarifying second weighing narrows down how
    the two fakes are distributed between the piles, then recursion (via
    solve_1 for lone fakes) finishes the job.
    """
    def fmt(pile):
        return ' '.join(map(str, pile))

    if len(coins) == 2:
        return coins
    third = len(coins) // 3
    if 2 * third < len(coins) - 2 * third:
        # Keep the unweighed pile from dominating (only matters for n == 5).
        third += 1
    left = coins[:third]
    right = coins[third:2 * third]
    rest = coins[2 * third:]

    def resolve_heavy(suspect):
        # `suspect` was the heavier weighed pile: either both fakes are in
        # `suspect`, or one is in `suspect` and one in `rest`.  Disambiguate
        # by weighing one half of `rest` against the other.
        if len(suspect) == 1:
            return solve_1(suspect), solve_1(rest)
        if len(rest) == 1:
            # Tiny follow-up: compare one suspect coin with the lone rest coin.
            print(f"TEST {suspect[0]} | {rest[0]}")
            second = input()
            if second == "LEFT":
                return solve_2(suspect)
            elif second == "RIGHT":
                return solve_1(suspect[1:]), solve_1(rest)
            elif second == "EQUAL":
                # Here len(suspect) == 2: both weighed coins are fake.
                return solve_1(suspect[:1]), solve_1(rest)
            return None
        half = len(rest) // 2
        rest_a = rest[:half]
        rest_b = rest[half:2 * half]
        rest_odd = rest[2 * half:]  # possible leftover coin from halving
        print(f"TEST {fmt(rest_a)} | {fmt(rest_b)}")
        second = input()
        if second == "LEFT":
            # One fake in suspect, one in rest_a.
            return solve_1(suspect), solve_1(rest_a)
        elif second == "RIGHT":
            # One fake in suspect, one in rest_b.
            return solve_1(suspect), solve_1(rest_b)
        elif second == "EQUAL":
            # Both fakes in suspect (or the odd leftover coin).
            return solve_2(suspect + rest_odd)

    print(f"TEST {fmt(left)} | {fmt(right)}")
    verdict = input()
    if verdict == "LEFT":
        # Distribution is 2-0-0 or 1-0-1 (left-right-rest).
        return resolve_heavy(left)
    elif verdict == "RIGHT":
        # Distribution is 0-2-0 or 0-1-1.
        return resolve_heavy(right)
    elif verdict == "EQUAL":
        # Either one fake on each side (1-1-0) or both in rest (0-0-2);
        # weigh all of `rest` against an equal-sized slice of left+right.
        matched = (left + right)[:len(rest)]
        print(f"TEST {fmt(rest)} | {fmt(matched)}")
        second = input()
        if second == "LEFT":
            # 0-0-2
            return solve_2(rest)
        elif second == "RIGHT":
            return solve_1(left), solve_1(right)
        elif second == "EQUAL":
            # Only 1-1-0 remains possible (arises when n == 5).
            return solve_1(left), solve_1(right)
+
+
# Driver: read n, build the labelled coins 1..n, and print both fakes.
num_coins = int(input())

coin_labels = list(range(1, num_coins + 1))

print("ANS", *solve_2(coin_labels))
+
Complexity: $\mathcal{O}(2\log_3(n))$
Coins 3
The new style seesaw requires us to completely ignore past solutions and find 3 fake coins
Hint 1
The single coin version of the problem can be solved in $\mathcal{O}(\log_4(n))$ guesses.
The double coin version of the problem can be solved in $\mathcal{O}(\log_2(n))$ guesses.
Hint 2
If you’ve solved the previous two problems, this should really just be applying the same mantra - how can I make 1/2 guesses to completely disambiguate which pile of coins the fake coins lie in.
Solution
Let’s start off by coding in solve1
and solve2
, as there isn’t much interesting to this:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+
def guess(c1, c2, c3):
    """Submit one three-pile weighing and parse the judge's ranking.

    Prints the three piles, then reads a response like "A B > C" — pile
    indices grouped into weight bands, heaviest band first.  Returns the
    bands as a list of lists of ints.
    """
    piles = [' '.join(map(str, pile)) for pile in (c1, c2, c3)]
    print(f"TEST {piles[0]} | {piles[1]} | {piles[2]}")
    bands = []
    for band in input().split(">"):
        bands.append([int(tok) for tok in band.split()])
    return bands
+
def solve_1(coins):
    """Find the single fake (heavy) coin using four-way splits (~log_4(n))."""
    if len(coins) == 1:
        return coins[0]
    if len(coins) == 2:
        # Weigh the two coins against each other (third pile empty).
        bands = guess(coins[:1], coins[1:], [])
        assert len(bands) == 3
        return coins[0] if bands[0][0] == 1 else coins[1]
    quarter = (len(coins) + 1) // 4
    piles = [coins[:quarter],
             coins[quarter:2 * quarter],
             coins[2 * quarter:3 * quarter],
             coins[3 * quarter:]]
    bands = guess(piles[0], piles[1], piles[2])
    assert len(bands) != 3
    if len(bands) == 2:
        # One weighed pile is strictly heavier: it holds the fake.
        return solve_1(piles[bands[0][0] - 1])
    # All weighed piles balance, so the fake hides in the leftover pile.
    return solve_1(piles[3])
+
def solve_2(coins):
    """Find the two fake (heavier) coins using three-pile weighings.

    NOTE(review): when len(pile4) < quarter, the follow-up weighing below
    probes only part of the heavy pile — it looks like a 2-0-0-0 split whose
    fakes fall outside `probe` may be mis-resolved; verify against the judge.
    """
    if len(coins) == 2:
        return coins[0], coins[1]
    elif len(coins) == 3:
        # Weigh the three singles; the heavier band is exactly the fakes.
        bands = guess([coins[0]], [coins[1]], [coins[2]])
        fakes = [coins[i - 1] for i in (1, 2, 3) if i in bands[0]]
        return fakes[0], fakes[1]
    elif len(coins) in [4, 5]:
        bands = guess([coins[0]], [coins[1]], [coins[2]])
        assert len(bands) != 3
        if len(bands) == 1:
            # All equal: both fakes hide among the unweighed coins.
            return solve_2(coins[3:])
        fakes = [coins[i - 1] for i in (1, 2, 3) if i in bands[0]]
        if len(fakes) == 1:
            # Second fake sits among the unweighed coins.
            fakes.append(solve_1(coins[3:]))
        return fakes[0], fakes[1]

    # At least 6 coins, so 3 * len(pile4) <= len(coins).
    quarter = (len(coins) + 2) // 4

    pile1 = coins[:quarter]
    pile2 = coins[quarter:2 * quarter]
    pile3 = coins[2 * quarter:3 * quarter]
    pile4 = coins[3 * quarter:]
    bands = guess(pile1, pile2, pile3)
    assert len(bands) != 3  # 3 distinct weights impossible with 2 fakes
    if len(bands) == 2:
        # Some imbalance among the weighed piles.
        if len(bands[0]) == 2:
            # Two tied heavy piles and one light pile: 1-1-0 split.
            fakes = []
            if 1 in bands[0]:
                fakes.append(solve_1(pile1))
            if 2 in bands[0]:
                fakes.append(solve_1(pile2))
            if 3 in bands[0]:
                fakes.append(solve_1(pile3))
            return fakes[0], fakes[1]
        else:
            # One strictly heavy pile: either 2-0-0-0 or 1-0-0-1.
            if bands[0][0] == 1:
                reordered = pile1 + pile2 + pile3
            elif bands[0][0] == 2:
                reordered = pile2 + pile3 + pile1
            elif bands[0][0] == 3:
                reordered = pile3 + pile1 + pile2

            # Compare a heavy-pile slice, a neutral slice, and pile4.
            probe = reordered[:len(pile4)]
            mirror = reordered[len(pile4):2 * len(pile4)]
            bands2 = guess(probe, mirror, pile4)

            assert len(bands2) == 2
            if len(bands2[0]) == 2:
                # probe and pile4 tie heavy: one fake in each.
                return solve_1(probe), solve_1(pile4)
            else:
                return solve_2(probe)
    else:
        # All three weighed piles equal: both fakes in the leftover pile.
        return solve_2(pile4)
+
Now, to solve the 3 coin case, let’s divide our coins into 3 piles of size $a$, plus the remainder.
Let’s do the case analysis for different outcomes of the weighing.
- If the outcome of the weighing has 3 distinct bands of weight (like
3 > 1 > 2
), then we know the heaviest pile has 2 fake coins, and the middle pile has 1 fake coin.- Final complexity: $1 + \log_4(a) + \log_2(a) = 3\log_4(a)$
- If the outcome of the weighing has 2 distinct bands of weight, with two heavier piles (
3 1 > 2
), then both heavy piles have 1 fake coin, and the remainder has 1 fake coin.- Final complexity: $1 + 2\log_4(a) + \log_4(n-3a)$
- If the outcome of the weighing has 2 distinct bands of weight, with two lighter piles (
3 > 1 2
), then the heavy pile has anywhere from 1 to 3 fake coins, and the remainder has anywhere from 0 to 2 fake coins.- This can simply be solved by recursing to find 3 coins in the heavy pile plus the remainder in $1 + T(n-2a)$
- If the outcome of the weighing has 1 distinct band of weight (all equal), then either all piles have a fake coin, or the remainder has all 3 fake coins
- We can disambiguate this by weighing the entire remainder against a subset of the weighed piles, giving a complexity of $2 + \text{max}(3\log_4(a), T(n-3a))$
However, you’ll notice the remainder causes some issues in the final case, and our logic can be made much simpler if we just make each weighed pile about $n/3$ in size. Then in the final case, all 3 fake coins being in the remainder is impossible!
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+
def solve_3(coins):
    """Find all three fake (heavier) coins among `coins`."""
    if len(coins) == 3:
        return coins[0], coins[1], coins[2]
    if len(coins) == 4:
        # Weigh three singles; with 3 fakes among 4 coins, the lighter band
        # (if any) is the one genuine coin.
        bands = guess([coins[0]], [coins[1]], [coins[2]])
        if len(bands) == 1:
            # All equal: the three weighed coins must all be fake.
            return coins[0], coins[1], coins[2]
        genuine = bands[1][0]  # 1-based index of the lone real coin
        fakes = [c for idx, c in enumerate(coins[:3], start=1) if idx != genuine]
        return fakes[0], fakes[1], coins[3]
    third = len(coins) // 3
    piles = [coins[:third], coins[third:2 * third], coins[2 * third:3 * third]]
    leftover = coins[3 * third:]

    bands = guess(piles[0], piles[1], piles[2])
    if len(bands) == 3:
        # Three weight bands (2-1-0-0): heaviest pile holds 2 fakes,
        # middle pile holds the remaining 1.
        found = list(solve_2(piles[bands[0][0] - 1]))
        found.append(solve_1(piles[bands[1][0] - 1]))
        return found[0], found[1], found[2]
    elif len(bands) == 2:
        if len(bands[0]) == 2:
            # Two tied heavy piles (1-1-0-1): one fake each, third in leftover.
            found = [solve_1(piles[i - 1]) for i in sorted(bands[0])]
            found.append(solve_1(leftover))
            return found[0], found[1], found[2]
        else:
            # One strictly heavy pile: 3-0-0-0, 2-0-0-1, or 1-0-0-2.
            # Pool it with the leftover and recurse (~log_3(n)).
            pool = leftover + piles[bands[0][0] - 1]
            return solve_3(pool)
    else:
        # 0-0-0-3 impossible by construction of `third`;
        # all piles equal means one fake in each weighed pile (1-1-1-0).
        return solve_1(piles[0]), solve_1(piles[1]), solve_1(piles[2])
+
Complexity: $\mathcal{O}(3\log_4(n))$ guesses
Cutting Board 1
These next two problems invite you to think about optimal strategies in a game of cuts.
Hint 1
Try to classify some small games as one of the four outcomes, try to make some rules for combining 2 boards.
Hint 2
- Can the game ever have a strategy where the 1st player always wins?
- Is it just boards with length = width where the 2nd player always wins?
Solution
Let’s try to analyse the smallest few games, and in doing so make some rules for combining boards together.
We’ll call a game $2$ if the second player wins, $1$ if the first player wins, and $V$ or $H$ for Vaughn/Hazel winning always.
1 2 3 4 5 6 7 8 9 10 1 2 V 2 H 3 4 5 6 7 8 9 10
The $1\times 1$ game is super simple - the first player can’t move, so the second player always wins. For the $2\times 1$ and $1\times 2$ games, only one player has a move to make, so they always win.
Let’s look at a few more games.
1 2 3 4 5 6 7 8 9 10 1 2 V V V V V V V V V 2 H 2 2 3 H 2 4 H 5 H 6 H 7 H 8 H 9 H 10 H
First off, any $n\times 1$ or $1\times n$ game handles exactly the same as a $2\times 1$.
Next, for the $2\times 2$, note that whoever moves first will create two games that we’ve previously calculated they cannot win. Playing a game on two boards which individually the other player has a strategy to win is a loss for the starting player, because the responding player always has a winning move on both boards, provided they always play on the same board as the previous player’s move.
Additionally, for $3\times 2$ and $2\times 3$, note that the game will always become a combination of a $2\times 2$ and a $1\times 2$/$2\times 1$ game.
We’ve come up with the following two rules for cutting board (assuming that these 2 boards are the best the players can come up with). These rules also apply to Hazel when the outcomes are flipped.
There is one more rule that comes up when analysing $2\times 4$. Note that while Vaughn could split into $2\times 1$ and $2\times 3$, this would result in a loss (As our $2+H$ rule states). Instead, Vaughn can split the game into $2\times 2$ and $2\times 2$. Since both games are losing for the second player, Vaughn can just follow whatever board Hazel makes a move on, and Vaughn will always win the game (If Hazel makes a move on the first $2\times 2$ box, then Vaughn has a winning move on one of the resultant cutting boards).
With just these three rules in hand, we can actually fill the entire board:
1 2 3 4 5 6 7 8 9 10 1 2 V V V V V V V V V 2 H 2 2 V V V V V V V 3 H 2 2 V V V V V V V 4 H H H 2 2 2 2 V V V 5 H H H 2 2 2 2 V V V 6 H H H 2 2 2 2 V V V 7 H H H 2 2 2 2 V V V 8 H H H H H H H 2 2 2 9 H H H H H H H 2 2 2 10 H H H H H H H 2 2 2
Hopefully by now you’re noticing the pattern. A proof left for the reader is why these 2s appear in boxes of size $2^k$. (Hint: Think about the first value in the row that can be a 2
rather than a H
. What does it require in the values above it in the column? And what about the first value in the row that is a V
, what needs to precede the V
in the same row?)
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+
import math

# Outcome depends only on floor(log2) of each dimension: the player whose
# dimension sits in a strictly higher power-of-two band wins; equal bands
# mean the second player wins.
n, m = list(map(int, input().split()))

# floor(log2(x)) computed exactly: math.log2 on large ints can round up
# across a power-of-two boundary (e.g. 2**53 - 1 floats to 53.0), whereas
# bit_length() - 1 is exact for all positive integers.
l2n, l2m = n.bit_length() - 1, m.bit_length() - 1

if l2n == l2m:
    print("2nd Player")
elif l2n > l2m:
    print("Vaughn")
else:
    print("Hazel")
+
Cutting Board 2
Hint 1
If you’ve seen the solution for Cutting Board 1 - try a similar approach of mapping out the first few values in both dimensions. You should see a very different picture.
Hint 2
Notice that:
- Adding a cut will always take the game into $n$ copies of the same board, each of which will be $2$, V, or H.
- Multiple games of $2$ are just $2$, Multiple games of V or H are just V or H.
As such, $2\times 2$, $2\times 3$, $2\times 5$ are essentially the same board, as far as this game is concerned. How is $2\times 4$ different?
Solution
As noted in the previous hint, let’s use the rules of combining boards to map out some smaller values:
1 2 3 4 5 6 7 8 9 10 1 2 V V V V V V V V V 2 H 2 2 V 2 V 2 V V V 3 H 2 2 V 2 V 2 V V V 4 H H H 2 H 2 H V 2 2 5 H 2 2 V 2 V 2 V V V 6 H H H 2 H 2 H V 2 2 7 H 2 2 V 2 V 2 V V V 8 H H H H H H H 2 H H 9 H H H 2 H 2 H V 2 2 10 H H H 2 H 2 H V 2 2
This table is a lot harder to decipher, but notice what the table looks like when I change the order of rows:
1 2 3 4 5 6 7 8 9 10 1 2 V V V V V V V V V 2 H 2 2 V 2 V 2 V V V 3 H 2 2 V 2 V 2 V V V 5 H 2 2 V 2 V 2 V V V 7 H 2 2 V 2 V 2 V V V 4 H H H 2 H 2 H V 2 2 6 H H H 2 H 2 H V 2 2 9 H H H 2 H 2 H V 2 2 10 H H H 2 H 2 H V 2 2 8 H H H H H H H 2 H H
We see strong bands of equal results. In a sense, all prime sized boards are equivalent, as are all boards of size 2 prime factors, and so on. Let’s continue this logic and permute the columns:
1 2 3 5 7 4 6 9 10 8 1 2 V V V V V V V V V 2 H 2 2 2 2 V V V V V 3 H 2 2 2 2 V V V V V 5 H 2 2 2 2 V V V V V 7 H 2 2 2 2 V V V V V 4 H H H H H 2 2 2 2 V 6 H H H H H 2 2 2 2 V 9 H H H H H 2 2 2 2 V 10 H H H H H 2 2 2 2 V 8 H H H H H H H H H 2
In general, the best strategy seems to be: Cut out a prime factor of your board size, and you are left with multiple boards that will be best for you.
Note that
- If this produces winning boards for the opposite team, there was no way for you to win.
- If this produces winning boards for your team, then you can win simply by following whichever game your opponent plays first on.
- If this produces winning boards for the second player, then you can win simply by following whichever game your opponent plays first on.
Therefore the solution boils down to finding the number of prime factors a number has.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+
MAX_N = int(2e6)

# Sieve of Eratosthenes: is_prime[i] says whether i is prime, 0 <= i <= MAX_N.
is_prime = [True] * (MAX_N + 1)
is_prime[0] = False
is_prime[1] = False

# Only multiples of primes up to sqrt(MAX_N) need crossing off (any
# composite has a factor <= its square root); slice assignment marks each
# prime's multiples in one C-level pass instead of a Python-level loop.
for jump in range(2, int(MAX_N ** 0.5) + 1):
    if not is_prime[jump]: continue
    is_prime[jump * jump::jump] = [False] * len(range(jump * jump, MAX_N + 1, jump))

primes = [i for i, v in enumerate(is_prime) if v]
+
def n_prime_factors(v):
    """Count the prime factors of v with multiplicity (e.g. 12 = 2*2*3 -> 3).

    Trial-divides only up to sqrt(v): any remainder > 1 is itself prime and
    counts once.  (The previous version scanned the entire sieved prime
    list on every call, even after v was fully factored, and silently
    undercounted for v with a prime factor above the sieve limit.)
    """
    n_factors = 0
    p = 2
    while p * p <= v:
        while v % p == 0:
            v //= p
            n_factors += 1
        p += 1
    if v > 1:
        # Leftover factor is prime.
        n_factors += 1
    return n_factors
+
# Read both dimensions; the player whose dimension has more prime factors
# (with multiplicity) wins, and ties go to the second player.
n, m = list(map(int, input().split()))

n_factors = n_prime_factors(n)
m_factors = n_prime_factors(m)

if n_factors == m_factors:
    print("2nd Player")
elif n_factors > m_factors:
    print("Vaughn")
else:
    print("Hazel")
+
Complexity: $\mathcal{O}(n)$
Divisors 0
Hint 1
Is there a formula we could be using that simplifies the sum of the first $n$ natural numbers?
If so, how would we change this formula for modulo?
Hint 2
Note that since we take the modulo of every individual value, the modulo-ed sequence repeats every $m$ values, so rather than computing the entire sequence, we can compute the sum of the first $m$ values and, excluding the final $n \% m$ bit of the sequence, we can simply count the number of repetitions.
Solution
As noted in the hint, the sequence repeats if we look at for example $n=14$, $m=4$.
\[1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 14\]after modulo by 4 becomes
\[1 + 2 + 3 + 0 + 1 + 2 + 3 + 0 + 1 + 2 + 3 + 0 + 1 + 2\]Note that the $1 + 2 + 3 + 0$ sum is repeated a bunch of times, except for the final value $ + 1 + 2$.
Using the triangle number formula, the sum $1 + 2 + 3 + 0$ is equal to $\frac{3\times (3+1)}{2} = 6$, and this sequence is repeated $\lfloor \frac{14}{m} \rfloor = 3$ times.
So the total sum is equal to $3 \times 6 + 1 + 2$, however this final bit can be computed as $\frac{(n \% m)\times((n \% m) + 1)}{2} = 3$
1
+2
+3
+4
+5
+6
+7
+8
+9
+
# Sum of (i mod m) for i = 1..n: every full period contributes
# 1 + 2 + ... + (m-1) + 0, plus a partial tail of n % m terms.
n, m = list(map(int, input().split()))

period_sum = (m * (m - 1)) // 2        # one full residue cycle
answer = period_sum * (n // m)         # complete cycles
tail = n % m
answer += (tail * (tail + 1)) // 2     # leftover 1..tail

print(answer)
+
Divisors 1
Hint 1
Try to figure out a rule for what natural number $n$ will generate the value $a_b$ in the sequence.
Note that $1_a$ will always be generated by natural number $a$.
Hint 2
Hopefully you’ve figured out that number $a_b$ will be generated by the natural number $a \times b$.
As such, ordering by appearance in the sequence should just be ordering by $a\times b$, with some care needing to be taken when comparing $a_b$ with $c_d$ and $a\times b = c\times d$.
Solution
As mentioned in the hint, the value $a_b$ is generated by the natural number $a \times b$, and so ordering the individual values by the natural number which generates them should sort the sequence in order.
In the case where $a \times b = c \times d$, notice that the smaller divisor will always be included in the sequence first, so after comparing $a\times b$ against $c\times d$, we need only compare $a$ against $c$.
1
+2
+3
+4
+5
+6
+7
+8
+
# The token a_b (b-th occurrence of divisor a) is generated by the natural
# number a*b; sorting by (a*b, a) recovers sequence order (smaller divisor
# first on ties), and b is kept so the tokens can be re-emitted.
token_count = int(input())
pairs = [[int(part) for part in tok.split("_")] for tok in input().split()]
keyed = sorted((a * b, a, b) for a, b in pairs)
print(" ".join(f"{a}_{b}" for _, a, b in keyed))
+
Divisors 2
Hint 1
Try to flip the problem on its head a bit and solve the case of counting how many 1s, 2s, 3s, etc. occur before the value you are looking for.
For example, 2 occurs 7 times before $3_5$.
Hint 2
For a natural number $n$, the value $a$ appears in the sequence before $n_1$ $\lfloor \frac{n}{a} \rfloor$ times.
This is all well and good for small $a$, but we can’t have a linear solution for this problem. You need to make use of the fact that for large $a$ (In particular, $a > \sqrt{n}$), the value of $\lfloor \frac{n}{a} \rfloor$ is always rather small (In particular $\lfloor \frac{n}{a} \rfloor \leq \sqrt{n}$)
Solution
To start with, let’s assume that we want to find the index of $n_1$ for some $n$ (The end of the sequence of divisors of $n$), since this will make our lives a bit easier.
Notice that for any value $a$, $a$ will occur in the sequence before $n_1$ $\lfloor \frac{n}{a} \rfloor$ times. Let’s graph this for a large enough $n$:
This graph has a lot of large unique values for $a \leq \sqrt{n}$, and a few smaller common values for $a \geq \sqrt{n}$ (You can argue this via the pigeonhole principle - there are $\sqrt{n}$ possible values)
As such, we can count the first $\sqrt{n}$ values ourselves, and then count “sections” of the graph that are of equal size up to and including $\sqrt{n}$ in height.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+
import math, sys
def val(inp):
    """Return the 1-based index at which the token "a_b" appears.

    "a_b" is the b-th occurrence of divisor a, generated while listing the
    divisors of the natural number a*b.  The index is the count of all
    tokens emitted for 1..a*b-1 (i.e. sum of floor((a*b-1)/j) over j,
    evaluated in O(sqrt) quotient groups) plus the number of divisors of
    a*b that are <= a.
    """
    def isqrt(x):
        # Exact floor square root: math.sqrt alone can misround for huge x
        # (beyond 2**52), which would corrupt the quotient grouping below.
        r = int(math.sqrt(x))
        while r * r > x:
            r -= 1
        while (r + 1) * (r + 1) <= x:
            r += 1
        return r

    special, noccurences = list(map(int, inp.split("_")))

    # The token appears while processing the natural number special*noccurences.
    nk = special * noccurences
    # First count every token up to (nk-1)_1, then add the divisors of nk
    # up to and including `special`.
    nk -= 1

    # nk // j takes only O(sqrt(nk)) distinct values: for each small
    # quotient i, count the whole range of j with nk//j == i at once, then
    # handle the remaining (large-quotient) j one at a time.
    root = isqrt(nk)

    print("nk", nk, file=sys.stderr)
    print("root", root, file=sys.stderr)

    nums = 0  # we start at index 1
    for i in range(1, root + 1):
        # Which j satisfy nk // j == i?
        if i == 1:
            nums += nk - nk // 2
            prev = nk // 2
            continue
        # Any j with i * j <= nk < (i+1) * j, i.e. j in (nk//(i+1), prev].
        smallest_excl = nk // (i + 1)
        largest_incl = prev
        prev = smallest_excl
        nums += (largest_incl - smallest_excl) * i
        print(f"nk // j = {i} for ({smallest_excl}, {largest_incl}]", file=sys.stderr)
    # Large quotients (nk // i > root), added individually.
    for i in range(1, root + 1):
        if nk // i <= root:
            break
        nums += nk // i

    nk += 1

    root = isqrt(nk)

    # Finally add the position of `special` within the divisors of nk.
    if special * special < nk:
        # `special` is below sqrt(nk): count the small divisors directly.
        for i in range(1, special + 1):
            if nk % i == 0:
                nums += 1
    else:
        # Count all divisors of nk, then drop those exceeding `special`.
        tot_turn = 0
        for i in range(1, root + 1):
            if nk % i == 0:
                tot_turn += 1
        tot_turn *= 2
        if root * root == nk:
            # Perfect square: sqrt(nk) was double-counted.
            tot_turn -= 1
        for i in range(1, root + 1):
            if nk % i == 0 and nk // i > special:
                tot_turn -= 1
            elif nk % i == 0:
                break
        nums += tot_turn

    if nk == 1:
        # The bookkeeping above breaks down for the very first token.
        return 1
    else:
        return nums
+
# Entry point: read the queried token "a_b" and print its sequence index.
query = input()
print(val(query))
+
This however has a much more elegant solution, found by team de
in the competition. Looking at the graph again, the graph is entirely the same when flipped along the axis $y=x$. So rather than counting the $\leq\sqrt{n}$ and $\geq \sqrt{n}$ cases separately, simply take the $\leq \sqrt{n}$ part of the graph, and double it.
This value then only double counts in the $\sqrt{n} \times \sqrt{n}$ box in the bottom left, which we can then subtract:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+
# Courtesy of `de`.
a, b = map(int, input().split("_"))
n = a * b
# Exact integer sqrt of n-1: int((n-1)**0.5) alone can misround for very
# large n, so correct it to the true floor.
m = int((n - 1) ** 0.5)
while m * m > n - 1:
    m -= 1
while (m + 1) * (m + 1) <= n - 1:
    m += 1
# number of values for i <= sqrt(n-1)
tot = 0
for i in range(1, m + 1):
    tot += (n - 1) // i
+
def numDivisor(i, n):
    """Count the divisors of n that are <= i (assumes 1 <= i <= n)."""
    tot = 0
    if i * i <= n:
        # i is below sqrt(n): just test each candidate divisor directly.
        for j in range(1, i + 1):
            if (n % j) == 0:
                tot += 1
        return tot

    # i is above sqrt(n): count ALL divisors via the small half, then
    # subtract the divisors k whose codivisor n//k exceeds i — exactly
    # those with k * i < n.  (The old float test `k < n/i` could misround
    # for very large n; the integer product test is exact.)
    m = int(n ** 0.5)
    # Correct any float rounding so m is the exact floor square root.
    while m * m > n:
        m -= 1
    while (m + 1) * (m + 1) <= n:
        m += 1
    numFacs = 0
    smth = 0
    for k in range(1, m + 1):
        if (n % k == 0):
            numFacs += 1
            if k * i < n:
                smth += 1

    numFacs *= 2
    if m * m == n:
        # Perfect square: sqrt(n) was double-counted.
        numFacs -= 1

    return numFacs - smth
+
+
# Double the counted area, subtract the double-counted sqrt-by-sqrt square
# in the bottom left (m*m), then add the divisors of n itself up until a.
answer = 2 * tot - m * m + numDivisor(a, n)
print(answer)
+
Divisors 3
A bit of a departure from all other problems in the contest, this problem asks you to approximate a sequence efficiently and effectively.
This problem ended up being a bit of a bad fit for a competition because a guarantee of maximum error is not the same as a practical guarantee of maximum error. Additionally, the accuracy and magnitude of the result was rather restrictive and made the problem a bit more annoying than it should have been.
Additionally, I didn’t do enough due diligence on checking my results, which made my initial solution incorrect (albeit accidentally accurate enough for the judge miraculously)
Hint 1
To solve the first boundary ( $\ln(n)$ ), you can solve this with a single line of code. (Moreso just a formula, than a line of code)
Also, I forgot to notice this in competition, but you’ll likely need an external package for extra decimal precision, like Pythons decimal
package.
Hint 2
The problem bounds imply that $\sqrt{n}$ should somehow come into play. Is there a way we count the contributions of $\frac{1}{a}$ for $a <= \sqrt{n}$ differently from all other $\frac{1}{a}$?
Solution
First, let’s solve the first test set bound.
Notice that, just like in the previous problem, for an end value $n$, and $a \leq n$, the value $\frac{1}{a}$ will be in the sequence $\lfloor\frac{n}{a}\rfloor$ times.
Therefore we can over-estimate the contribution for $\frac{1}{a}$ in total as $\frac{1}{a} \times \frac{n}{a} = \frac{n}{a^2}$.
Summing over all $a$, we get the following sequence, which is a rather famous sequence:
\[\frac{n}{1^2} + \frac{n}{2^2} + \frac{n}{3^2} + \cdots + \frac{n}{n^2} = n (\frac{1}{1^2} + \frac{1}{2^2} + \frac{1}{3^2} + \cdots + \frac{1}{n^2}) \approx n \frac{\pi^2}{6}\]The error bound on the approximation of the sum of reciprocals is $\frac{1}{n}$, meaning that ignoring the error that removing the floor contributes, we are within $\frac{n}{n} = 1$ of the correct solution. However removing the floor can add as much as $\ln(n)$ to the result.
To solve the second test set bound, there was one intended solution, which didn’t end up actually ensuring the error bounds were met, and another solution that was found by team de
. We’ll start with the semi-faulty solution.
Solution 1 - Modifying the test set 1 sequence.
Notice that the estimation error from $\frac{1}{a}\lfloor \frac{n}{a} \rfloor$ to $\frac{n}{a^2}$ is $\frac{n \% a}{a^2}$.
Let’s look at the full error expression for $n=20$:
\[\text{err} = \frac{0}{1^2} + \frac{0}{2^2} + \frac{2}{3^2} + \frac{0}{4^2} + \frac{0}{5^2} + \frac{2}{6^2} + \frac{6}{7^2} + \frac{4}{8^2} + \frac{2}{9^2} + \frac{0}{10^2} + \frac{9}{11^2} + \frac{8}{12^2} + \frac{7}{13^2} + \frac{6}{14^2} + \frac{5}{15^2} + \frac{4}{16^2} + \frac{3}{17^2} + \frac{2}{18^2} + \frac{1}{19^2}\]Notice that there are bands of rather well behaved fractions, for example from denominator 11 to 19. In general there will be an arithmetic progression on the numerators between the denominators of $\frac{n}{a+1}$ and $\frac{n}{a}$. Let’s try creating an estimator for these kinds of sequences.
\[R := \frac{a + bc}{x^2} + \frac{a+b(c-1)}{(x+1)^2} + \frac{a+b(c-2)}{(x+2)^2} + \cdots + \frac{a}{(x+c)^2}\]This sequence would be easier to resolve if the numerators increased with the denominators, rather than the opposite direction, so let’s do a manipulation.
\[(x + \frac{a}{b} + c) \times (\frac{1}{x^2} + \frac{1}{(x+1)^2} + \frac{1}{(x+2)^2} + \cdots + \frac{1}{(x+c)^2}) - \frac{R}{b} = \frac{x}{x^2} + \frac{x + 1}{(x+1)^2} + \frac{x + 2}{(x+2)^2} + \cdots + \frac{x+c}{(x+c)^2}\]Both sequences above have well known approximations, shown below:
\[\frac{1}{1^2} + \frac{1}{2^2} + \frac{1}{3^2} + \cdots + \frac{1}{n^2} = \frac{\pi^2}{6} - \frac{1}{n} - [0, \frac{1}{(n)(n+1)}]\] \[\frac{1}{1^2} + \frac{2}{2^2} + \frac{3}{3^2} + \cdots + \frac{n}{n^2} = \frac{1}{1} + \frac{1}{2} + \frac{1}{3} + \cdots + \frac{1}{n} = \ln(n) + \gamma + \frac{1}{2n} - [0, \frac{1}{8n^2}]\]where $\gamma$ is a constant. Substituting this into the equation above gives:
\[(x + \frac{a}{b} + c) \times (\frac{\pi^2}{6} - \frac{1}{x+c} - [0, \frac{1}{(x+c)(x+c+1)}] - \frac{\pi^2}{6} + \frac{1}{x-1} + [0, \frac{1}{(x-1)(x)}]) - \frac{R}{b} = \ln(x+c) + \gamma + \frac{1}{2(x+c)} - [0, \frac{1}{8(x+c)^2}] - \ln(x-1) - \gamma - \frac{1}{2(x-1)} + [0, \frac{1}{8(x-1)^2}]\] \[(x + \frac{a}{b} + c) \times (\frac{1}{x-1} - \frac{1}{x+c} + [-\frac{1}{(x+c)(x+c+1)}, \frac{1}{(x-1)(x)}]) - \frac{R}{b} = \ln(\frac{x+c}{x-1}) + \frac{1}{2(x+c)} - \frac{1}{2(x-1)} + [-\frac{1}{8(x+c)^2}, \frac{1}{8(x-1)^2}]\]Which solving for $R$ gives us
\[R = (xb + a + bc) \times (\frac{1}{x-1} - \frac{1}{x+c}) - b \times (\ln(\frac{x+c}{x-1}) + \frac{1}{2(x+c)} - \frac{1}{2(x-1)})\]with an error bound at most $\frac{xb + a + bc}{(x-1)(x)} + \frac{b}{8(x-1)^2}$.
Let’s use this estimate for the denominators $11$ through to $19$. This has $b=1$, $x=11$, $c=8$, $a=1$:
\[R = (11 + 1 + 8) \times (\frac{1}{10} - \frac{1}{19}) - (\ln(\frac{19}{10}) + \frac{1}{38} - \frac{1}{20})\]which gives about $0.07$ off of the actual solution
Choosing $d$ from $1$ up until $m := \lfloor \sqrt{n} \rfloor$ we can look at the denominator range $\frac{n}{d}$ down to $\frac{n}{d+1}$.
This has $b = d$, $c = \lfloor\frac{n}{d}\rfloor - \lfloor\frac{n}{d+1}\rfloor - 1$, $x = \lfloor\frac{n}{d+1}\rfloor$, and $a = n \% \lfloor \frac{n}{d} \rfloor$.
\[R = (d \lfloor \frac{n}{d+1} \rfloor + (n \% \lfloor \frac{n}{d} \rfloor ) + d \times (\lfloor\frac{n}{d}\rfloor - \lfloor\frac{n}{d+1}\rfloor - 1)) \times (\frac{1}{\lfloor\frac{n}{d+1}\rfloor -1} - \frac{1}{\lfloor\frac{n}{d}\rfloor - 1}) - d\times (\ln(\frac{\lfloor\frac{n}{d}\rfloor - 1}{\lfloor\frac{n}{d+1}\rfloor-1}) + \frac{1}{2(\lfloor\frac{n}{d}\rfloor - 1)} - \frac{1}{2(\lfloor\frac{n}{d+1}\rfloor-1)})\] \[R = (d \times (\lfloor\frac{n}{d}\rfloor - 1) + (n \% \lfloor \frac{n}{d} \rfloor )) \times \frac{\lfloor\frac{n}{d}\rfloor - \lfloor\frac{n}{d+1}\rfloor}{(\lfloor\frac{n}{d}\rfloor - 1) \times (\lfloor\frac{n}{d+1}\rfloor - 1)} - d\times (\ln(\frac{\lfloor\frac{n}{d}\rfloor - 1}{\lfloor\frac{n}{d+1}\rfloor-1}) + \frac{\lfloor\frac{n}{d+1}\rfloor - \lfloor\frac{n}{d}\rfloor}{2(\lfloor\frac{n}{d}\rfloor - 1)\times (\lfloor\frac{n}{d+1}\rfloor-1)})\] \[R = (n - d) \times \frac{\lfloor\frac{n}{d}\rfloor - \lfloor\frac{n}{d+1}\rfloor}{(\lfloor\frac{n}{d}\rfloor - 1) \times (\lfloor\frac{n}{d+1}\rfloor - 1)} - d\times (\ln(\frac{\lfloor\frac{n}{d}\rfloor - 1}{\lfloor\frac{n}{d+1}\rfloor-1}) + \frac{\lfloor\frac{n}{d+1}\rfloor - \lfloor\frac{n}{d}\rfloor}{2(\lfloor\frac{n}{d}\rfloor - 1)\times (\lfloor\frac{n}{d+1}\rfloor-1)})\]Although in practice I found
\[R = 1 - d \times \ln(\frac{d+1}{d})\]A relatively good and simple estimator for the above. (But the solution will use the lengthy approximation)
What is the error in this approximation? Well, there ends up being lots of cancellations in errors, since we are combining together lots of chained approximations, and so what was a positive error in the previous step now becomes the same negative error (this is not true for all error, for example some of the error in the harmonic approximation, but it is true for some).
Unless I’ve screwed something up (very possible) the total error ends up being a small factor of $\frac{1}{\sqrt{n}}$. This seems to at least be true in practice.
For the values $\frac{n \% c}{c^2}$ for $c \leq \sqrt{n}$, we can just compute those manually.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+
import sys
+import math
+from decimal import Decimal, getcontext
+
+# Approximates sum over c of (n mod c)/c^2: start from the n*pi^2/6
+# estimate, then subtract corrections batched over runs of equal n//c
+# for large c, and exactly for each small c.
+getcontext().prec = 50
+
+n = int(input())
+
+# Subtract 1 for the trail off
+res = Decimal(n) * Decimal(math.pi) * Decimal(math.pi) / Decimal(6) - 1
+
+print("Pretty good approximation:", res, file=sys.stderr)
+
+# Crossover point: values of c below `ceil` are handled exactly.
+ceil = min(n, int(1e6))
+
+# Now, we need to reduce by a%d/d^2 for all d <= a.
+for d in range(1, n // ceil):
+ d = Decimal(d)
+ # smol/beeg bracket the run of c values where n//c is constant (= d).
+ smol = Decimal(n // (d + 1))
+ beeg = Decimal(n // d)
+ # Closed-form reduction over the whole run (see derivation above).
+ first_part = d * smol + (n % beeg) + d * (beeg - smol - 1)
+ second_part = Decimal(1) / Decimal(smol - 1) - Decimal(1) / Decimal(beeg - 1)
+ third_part = Decimal.ln((beeg - 1) / (smol - 1)) + Decimal(1) / (2 * (beeg - 1)) - Decimal(1) / (2 * (smol - 1))
+
+ reduction = first_part * second_part - d * third_part
+ # print(f"1/{n//d}^2 + ... + {n//(d+1)}/{n//(d+1)}^2 = {reduction}")
+ res -= reduction
+
+# Below sqrt(a), we can manually subtract the value
+for d in range(2, ceil):
+ res -= Decimal(n % d) / Decimal(d*d)
+
+print("Better:", res, file=sys.stderr)
+
+print(res)
+
Solution 2 - Other approximations
This solution was found by team de
in competition.
Rather than sticking with the $\frac{n\pi^2}{6}$ approximation, this solution instead goes back to the original sequence and looks at it with a new viewpoint:
Let’s collect all of the $\frac{1}{1}s$, $\frac{1}{2}s$, and so on, in distinct columns, where the height of the column represents how many times that fraction is used.
We can sum the columns before and after $m = \lfloor \sqrt{n} \rfloor$ differently.
For those before $m$, we can simply find each column’s contribution by adding $a \times \lfloor\frac{n}{a}\rfloor$. For those after $m$, notice that $\lfloor \frac{n}{a} \rfloor$ will only take at most $m+1$ unique values (Since $\lfloor \frac{n}{m} \rfloor \leq m+1$). This means that if, rather than summing by column, we instead sum by row, we’ll have only $m$ sets of values to compute, rather than $n-m$.
It’s worth noting that before $m$, we have $m$ distinct columns, and so after $m$, we have $m$ distinct rows (subject to off by one issues)
What do our rows of the graph look like? Well, using the previous image, every row (From column $m$ onwards), will be a sum of consecutive reciprocals up until some point. For example, for $n=9$, $m=3$, we have:
\[\frac{1}{3} +\] \[\frac{1}{3} + \frac{1}{4} +\] \[\frac{1}{3} + \frac{1}{4} + \frac{1}{5} + \frac{1}{6} + \frac{1}{7} + \frac{1}{8} + \frac{1}{9}.\]Now each of these rows we can use the approximation $\frac{1}{1} + \frac{1}{2} + \frac{1}{3} + \ldots + \frac{1}{n} = \ln(n) + \gamma + \frac{1}{2n} + \mathcal{O}(\frac{1}{n^2})$.
Applying this gives us:
\[\ln(\frac{3}{2}) + \frac{1}{6} - \frac{1}{4} + \mathcal{O}(\frac{1}{m^2}) +\] \[\ln(\frac{4}{2}) + \frac{1}{8} - \frac{1}{4} + \mathcal{O}(\frac{1}{m^2}) +\] \[\ln(\frac{9}{2}) + \frac{1}{18} - \frac{1}{4} + \mathcal{O}(\frac{1}{m^2})\]So we can use this to solve the problem with a total error bound of $\mathcal{O}(\frac{m}{m^2}) = \mathcal{O}(\frac{1}{m})$!
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+
# Rephrased version of team `de`s solution.
+# Splits the sum at m ~ sqrt(n): columns 1..m are summed exactly,
+# the remaining ~m rows use the harmonic-sum approximation
+# H(k) ~ ln(k) + gamma + 1/(2k).
+from decimal import Decimal, getcontext
+
+getcontext().prec = 50
+
+n = int(input())
+
+m = max(1,int(n**0.5))
+m1 = n//m
+
+total = Decimal("0")
+
+# Handle the first m columns
+for i in range(1, m+1):
+ total += Decimal(n//i)/Decimal(i)
+# Handle the remaining m1 rows
+for i in range(1, m1+1):
+ # ln(n//i / m) + 1/(2*(n//i)) - 1/(2m): the gamma terms cancel.
+ total += Decimal.ln(Decimal(n//i) / Decimal(m)) + Decimal(1) / Decimal(2 * (n//i)) - Decimal(1) / Decimal(2 * m)
+
+print(total)
+
Lights 1
Hint 1
Simulating the problem takes $\mathcal{O}(n\ln(n))$ time. Too much - there actually aren’t many lights that will be turned on, and we can generate them in a neat sequence.
Hint 2
Consider a (faulty) proof that no light should be turned on. What is wrong with it, and what does this tell us about the solution?
Consider any light $n$. Take any factor of $n$, call it $a$. Note that $\frac{n}{a}$ will be another distinct factor of $n$ - This is true for all $a$ we could have chosen. Since this is the case (every factor has a unique pair), the total number of factors of $n$ is even. Therefore light $n$ is off.
Solution
The issue with the proof given in Hint 2, is that for square numbers, the “pairing” maps the square root of a number with itself. Take $36$ for example, the divisors $1, 2, 3$ are paired with $36, 18, 12$ respectively, but $6$ is its own pair.
In fact, square numbers are the only numbers for which the proof given in Hint 2 doesn’t work, for this very reason. So the problem really boils down to counting how many square numbers are less than $n$. We can do this easily by simply returning $\lfloor \sqrt{n} \rfloor$!
1
+2
+3
+4
+5
+6
+
import math
+
+n = int(input())
+
+print(math.floor(math.sqrt(n)))
+
+
Lights 2
Hint 1
Note importantly that if Robot $a$ flicks a light switch, then Robot $a-1$ also flicks the same switch.
Hint 2
Consider the first $2^i$ lights. How many have been flicked once? twice? three times?
Solution
Using Hint 1, what we really need to find are the lights which are flicked on by Robot 1, but not Robot 2 (those that are flicked once), the lights which are flicked on by Robot 3, but not Robot 4 (those that are flicked thrice), and so on.
The lights that are flicked on by Robot 1, but not Robot 2 are those which are divisible by $2^0=1$, but are not divisible by $2^1=2$.
For the first $n$ lights, exactly $\lfloor \frac{n+1}{2} \rfloor$ of them will satisfy this rule.
The lights that are flicked on by Robot 3, but not Robot 4 are those which are divisible by $2^2=4$, but are not divisible by $2^3=8$. If we floor divide $n$ by $4$, and call this $m$, there are $m$ numbers divisible by $4$. Divide all these numbers by $4$. Now the question is simply how many of these are divisible by $2$, rather than divisible by $8$! So this is just the same as the first Robot question.
In general, the number of odd-flicked lights will be:
\[\lfloor \frac{n+1}{2} \rfloor + \lfloor \frac{\lfloor \frac{n}{4} \rfloor + 1}{2} \rfloor + \lfloor \frac{\lfloor \frac{n}{16} \rfloor + 1}{2} \rfloor + \lfloor \frac{\lfloor \frac{n}{64} \rfloor + 1}{2} \rfloor + \ldots\]Until this flooring starts giving 0 terms.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+
n = int(input())
+
+# Sum, over k = 0, 1, 2, ..., the count of lights divisible by 4^k but
+# not by 2*4^k (lights flicked an odd number of times).
+total = 0
+cur_divisor = 1
+while True:
+ div_by_divisor = n // cur_divisor
+ # Of the numbers <= n//4^k, those that are odd after the division.
+ not_div_by_2 = (div_by_divisor+1)//2
+ if not_div_by_2 == 0:
+ break
+ total += not_div_by_2
+ cur_divisor *= 4
+
+print(total)
+
Lights 3
Hint 1
Assuming you’ve solved Lights [I], this shouldn’t be too much of a stretch.
If a light $n$ is on in this configuration, what does it tell us about the divisors of $n$?
Hint 2
This problem statement counts the number of odd divisors of a number. For an odd number, how does this relate to the number of total divisors? For a number which has a prime factorisation including $2^i$, how does this relate to the number of total divisors?
Solution
One important tool we can use for this problem is the prime factorisation of a number. Take $12$ for example, it has a prime factorisation of $2^23^1$. Note that any divisor of 12 is created simply by setting the power of $2$ to be anything from $0, 1, 2$, and the power of $3$ to be anything from $0, 1$. ($1 = 2^03^0$, $6 = 2^13^1, \ldots$).
In general, if your prime factorisation is $a_1^{a_2}b_1^{b_2}c_1^{c_2}\cdots$, then your number has $(a_2+1)(b_2+1)(c_2+1)\cdots$ divisors, to account for all choices of the indices.
Now, for odd numbers, any divisor is an odd divisor, so the same theory applies - only square numbers work.
But what about for evens? Take some number $n = 2^i3^j5^k$. This number has $(i+1)(j+1)(k+1)$ divisors, but the number of odd divisors is just the number of divisors where we picked the power of $2$ to be $2^0$.
Therefore the number of odd divisors of $n$ is $(j+1)(k+1)$. In other words, it's the number of divisors of the odd number $3^j5^k$, which must be a square number.
So the only lights that should be on, are odd square numbers, and odd square numbers times a power of two.
Notice however, that since $2 \times 2$ is itself a square number, we can actually count all of the above numbers by simply counting all square numbers, and all square numbers times plain old 2. Take $2^3 \times 5^2$ for example, we can write this instead as $2 \times 10^2$.
We can count the number of squares, and the number of numbers which are two times a square simply with
\[\lfloor \sqrt{n} \rfloor + \lfloor \sqrt{\frac{n}{2}} \rfloor\] 1
+2
+3
+4
+5
+6
+
import math
+
+n = int(input())
+
+total = math.floor(math.sqrt(n)) + math.floor(math.sqrt(n//2))
+print(total)
+
Lights 4
Hint 1
The logic used in Lights 3 around how many divisors a number has will remain useful here:
In general, if your prime factorisation is $a_1^{a_2}b_1^{b_2}c_1^{c_2}\cdots$, then your number has $(a_2+1)(b_2+1)(c_2+1)\cdots$ divisors.
Hint 2
This problem needs a rather sophisticated prime counting function.
Solution
Let’s use the rule given in Hint 1 to try to come up with a way of figuring out if a light is on.
Since the number of divisors is equal to $(a_2+1)(b_2+1)(c_2+1)\cdots$, the number of divisors will be prime only when:
- There is a single prime divisor of the number (Since $(a_2+1)(b_2+1)$ is already non-prime), and
- $a_2+1$ is prime.
In other words, the prime factorisation of $n$ must be $p^i$, where $i+1$ is prime.
Now, we could compute this linearly using a prime sieve, however we need to be a bit faster than this. There’s actually a batched way that we could solve this.
Let’s first count the number of values before $n$ which are represented as $p^1$ - This is just the number of primes before $n$. Next, we’ll count the number of values before $n$ which are represented as $p^2$ - This is just the number of primes that appear before $\sqrt{n}$ (Since squaring the left side gives a number we are looking for, and squaring the right side gives $n$).
In general, if $\pi$ is the prime counting function ($\pi(n)$ = number of primes at or before $n$), then we need to compute
\[\pi(n) + \pi(n^{1/2}) + \pi(n^{1/4}) + \pi(n^{1/6}) + \pi(n^{1/10}) + \cdots\]Now we just need a fast prime counting function. Luckily, the wikipedia page for prime counting functions has some tools we can use to make a faster prime counting function, in particular following a link to The Meissel Lehmer Algorithm - you can see that there exists an optimised version that solves the problem in $\mathcal{O}(n^\frac{2}{3})$ time, however for our purposes we can just use some simple rules from the Meissel Lehmer algorithm and make something sublinear.
The primary thing to note from the mention of the algorithm in the prime counting function page, and, the main page for the algorithm, is that for our purposes, picking $y = \sqrt{n}$ and $n = \pi(y)$, then computing $\pi(m) = \phi(m, n) + n - 1 - P_2(m, n)$ is easy, because $P_2(m, n)$ is 0.
So all that’s left is simply to implement the recursion of $\phi$ efficiently. We can use a sieve up to $10^6$ for fast computation for small numbers, and for larger results simply defer to recursion:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+
import sys
+import math
+
+# Counts numbers <= n of the form p^i with i+1 prime, via a simplified
+# Meissel-Lehmer prime count: pi(m) = phi(m, n) + n - 1 with y = sqrt(m).
+sys.setrecursionlimit(int(1e5))
+
+n = int(input())
+
+# pi(n) + pi(sqrt(n)) + pi(n^1/4) + ...
+
+prime_limit = int(3e6)
+
+# Sieve of Eratosthenes; pi[x] = number of primes <= x, for fast lookups.
+is_prime = [True] * (prime_limit+1)
+pi = [0] * (prime_limit+1)
+primes = []
+is_prime[0] = False
+is_prime[1] = False
+for x in range(2, prime_limit+1):
+ pi[x] = pi[x-1]
+ if not is_prime[x]: continue
+ pi[x] += 1
+ primes.append(x)
+ for pos in range(2*x, prime_limit+1, x):
+ is_prime[pos] = False
+
+def phi(m, n):
+ # Legendre's phi: count of integers <= m not divisible by any of the
+ # first n primes. Base case 1: only the number 1 survives.
+ if m <= prime_limit and pi[m] <= n:
+ return 1
+ if n == 0:
+ return math.floor(m)
+ return phi(m, n-1) - phi(m//primes[n-1], n-1)
+
+def fast_prime(n):
+ # pi(m) = phi(m, pi(sqrt(m))) + pi(sqrt(m)) - 1 (P2 term is zero here).
+ m = n
+ y = math.floor(math.sqrt(m))
+ n = pi[y]
+ return phi(m, n) + n - 1
+
+total = 0
+for x in range(1, math.floor(math.log2(n)) + 1):
+ # p^x <= n contributes when x+1 is prime; count primes <= n^(1/x).
+ if is_prime[x+1]:
+ total += math.floor(fast_prime(math.floor(math.pow(n, 1/x))))
+ print(x, total, file=sys.stderr)
+print(total)
+
One other optimisation that can be made is noticing that the recursion tree is often quite long with a lot of small branches (At some stage if dividing $n$ by any large prime $p$ will give you the base case, then we can use binary search to find the first prime which won’t hit the base case)
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+
import sys
+import math
+
+# Same as the previous solution, with one optimisation in phi: long runs
+# of recursion branches that all bottom out at the base case are skipped
+# in one step via binary search.
+sys.setrecursionlimit(int(1e5))
+
+n = int(input())
+
+# pi(n) + pi(sqrt(n)) + pi(n^1/4) + ...
+
+prime_limit = int(3e6)
+
+# Sieve of Eratosthenes; pi[x] = number of primes <= x.
+is_prime = [True] * (prime_limit+1)
+pi = [0] * (prime_limit+1)
+primes = []
+is_prime[0] = False
+is_prime[1] = False
+for x in range(2, prime_limit+1):
+ pi[x] = pi[x-1]
+ if not is_prime[x]: continue
+ pi[x] += 1
+ primes.append(x)
+ for pos in range(2*x, prime_limit+1, x):
+ is_prime[pos] = False
+
+def phi(m, n):
+ # Legendre's phi: integers <= m with no factor among the first n primes.
+ if m <= prime_limit and pi[m] <= n:
+ return 1
+ if n == 0:
+ return math.floor(m)
+ # Try binary searching through a bunch of the easy to solve stuff.
+ if n > 50 and m > prime_limit and m//primes[n-1] <= prime_limit and 2*pi[m//primes[n-1]] <= n-1:
+ hi = n
+ lo = 10
+ while hi - lo > 2:
+ mid = (hi + lo) // 2
+ new_m = m//primes[mid]
+ if new_m <= prime_limit and 2*pi[new_m] <= mid:
+ # We can go lower
+ hi = mid + 1
+ else:
+ # We can't go this low
+ lo = mid + 1
+ # Skip from n to mid in n-mid steps, since all deductions will just be -1.
+ return phi(m, mid) - (n-mid)
+ return phi(m, n-1) - phi(m//primes[n-1], n-1)
+
+def fast_prime(n):
+ # pi(m) = phi(m, pi(sqrt(m))) + pi(sqrt(m)) - 1 (P2 term is zero here).
+ m = n
+ y = math.floor(math.sqrt(m))
+ n = pi[y]
+ return phi(m, n) + n - 1
+
+total = 0
+for x in range(1, math.floor(math.log2(n)) + 1):
+ # p^x <= n contributes when x+1 is prime; count primes <= n^(1/x).
+ if is_prime[x+1]:
+ total += math.floor(fast_prime(math.floor(math.pow(n, 1/x))))
+ print(x, total, file=sys.stderr)
+print(total)
+
+
Misc 0
Hint 1
It might be first good to simplify the fraction given to you, and seeing what you can do with this information.
Hint 2
If the simplified fraction of the problem is $\frac{c}{d}$, then at every integer time you’ll actually see all values of $\frac{x}{d}$ around the circle. So what does $d$ tell us about whether we hit the other side?
Solution
To quickly prove the result of Hint 2, notice that if the simplified fraction is $\frac{c}{d}$, then we know that $c$ and $d$ are coprime, in other words $\text{gcd}(c, d) = 1$. Then there exists some values $x$ and $y$ such that $cx + dy = 1$. Consider where we will be after $x$ seconds. We’ll be at $\frac{cx}{d} = \frac{1 - dy}{d} = \frac{1}{d} - y = \frac{1}{d}$ rotation around the circle (If $x$ is negative, just keep adding $d$ seconds until it is positive and you’ll get the same result). So in $x$ seconds we can move $\frac{1}{d}$ around the circle, and so in $x\times a$ seconds we can move to $\frac{a}{d}$ around the circle for any integer $a$. Hopefully it is relatively clear that for a simplified fraction of $\frac{c}{d}$, any rotation not expressible as $\frac{a}{d}$ is not possible after an integer amount of seconds.
Now, all we need to determine is whether we hit the opposite side of the circle, $\frac{1}{2}$. This is only possible (and always possible) if $\frac{1}{2}$ is expressible as $\frac{a}{d}$ for some $a$.
Which hopefully you can see is always possible if $d$ is divisible by $2$.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+
a, b = list(map(int, input().split()))
+
+# Simple lcm work.
+for x in range(2, min(a, b)):
+ while a % x == 0 and b % x == 0:
+ a //= x
+ b //= x
+
+# 1/b is the jump size.
+if b % 2 == 0:
+ print("Other axis!")
+else:
+ print("Free!")
+
Misc 1
Hint 1
This problem is best viewed through the lens of recursion. Your recursion will likely need to look back at all previous values (I.E., parens(4)
can be written as some combination of parens(3)
, parens(2)
, parens(1)
, parens(0)
)
Hint 2
Think about all possible parenthesis strings of containing $n$ closed parentheses. Each of these valid strings must start with an open parenthesis, which is closed at some point. What do I know about the strings in between these two parentheses, as well as after these two parentheses?
Solution
To answer Hint 2, the inside string and following string must both represent valid parenthesis strings!
Therefore, we can construct a valid parenthesis string of length $n$ by deciding:
- How many parenthesis will occur inside the first closed parenthesis, call it $a$
- What is a valid parenthesis string of length $a$ to use inside
- What is a valid parenthesis string of length $n-a-1$ to use outside
And this informs our recursive counting function. To compute parens(n)
, simply:
- Iterate for all $a$ from 0 to $n-1$
- Compute
parens(a) * parens(n-1-a)
- Add to the total and return the sum.
We just need to add some modular arithmetic to the solution and we are done:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+
import sys
+
+sys.setrecursionlimit(int(1e5))
+
+MOD = int(1e9+7)
+
+n = int(input())
+
+# Memo table; DP[n] caches parens(n) (the nth Catalan number mod MOD).
+DP = [None] * int(1e3 + 5)
+
+# How many parens patterns are there?
+def parens(n):
+ if DP[n] is not None:
+ return DP[n]
+ # 0 or 1 pairs: exactly one valid string (including the empty one).
+ if n <= 1:
+ return 1
+ total = 0
+ for a in range(n):
+ # There are a parens in the first pattern
+ total += (parens(a) * parens(n-a-1)) % MOD
+ total %= MOD
+ DP[n] = total
+ return total
+
+print(parens(n))
+
Note - These are called the Catalan Numbers, and I was going to include many more problems featuring them originally. If you’re looking for a beautifully unique proof, look at Bertrand’s Ballot Theorem, a generalisation of the Catalan numbers.
Misc 2
I guess I wasn’t thinking too much when I wrote this problem since it includes a data structure, but I think I count trees as more math than data structure, they are simply too fundamental :)
Hint 1
Try to think about the contributions on the left and right side of the removed edge separately (as well as that contributed by the edge itself separately). These three values when combined give you the answer.
Hint 2
The easiest of the three values to calculate is the amount contributed by the removed road itself.
This is simply the road's value, times the number of nodes on the left side of the road, times the number of nodes on the right side of the road. This is because left times right gives you the number of paths which cross the road.
The computation of the other two values (left road contributions, right road contributions) aren’t actually that much more complicated than above.
Solution
Let’s take a graph and remove some edge in the middle. First, we’ll try counting all contributions on the right side of the removed edge.
In fact, let’s be even more specific - let’s count the contributions on the right side of the removed edge, originating from paths starting at vertex $a$.
Counting all the paths, you’ll notice that “leaf” edges only contribute once, whereas the adjacent edges are counted multiple times - once for the shared node, and then once each for each adjacent leaf edge:
We can write a recursive function to compute how many times each edge is counted, simply by counting how many nodes live below the edge. After computing this we can easily calculate the total contribution by summing the contributed amount over each edge. Let’s call this result sumRight
.
Notice in our workings above, the location of $a$ never really mattered. The logic for every node on the left side of the edge is the same. As a result, the total contribution of all right edges is simply sumRight * nodesLeft
, where nodesLeft
is the amount of nodes on the left side of the removed edge.
We can do the exact same process with the left edges, and we’re done!
Note: My original solution was written when I was planning to make this a query problem (Exact same problem, but rather than a single removed edge, we can think of $10^5$ possible removed edges, and what each of these removals would do for the graph), so my solution is over-engineered and hard to understand. I’ve also included team de
s approach which does what we outline above in a much simpler manner with a tree search centered at the removed edge, rather than fixing the tree structure at an arbitrary node, like my solution does.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+
N = int(input())
+from collections import defaultdict
+edges = []
+for i in range(N-1):
+ edges.append(list(map(int,input().split())))
+
+x = int(input())
+
+dedge = edges[x-1]
+
+G = defaultdict(list)
+
+for e in edges:
+ if e[0] != dedge[0] or e[1] != dedge[1]:
+ G[e[0]].append([e[1],e[2]])
+ G[e[1]].append([e[0],e[2]])
+
+
+
+def distance_from(node):
+ dist = {node: 0}
+ visited = {node: True}
+
+ def search(n):
+ for e in G[n]:
+ if e[0] not in visited:
+ visited[e[0]] = True
+ dist[e[0]] = dist[n] + e[1]
+ search(e[0])
+ search(node)
+
+ return dist
+
+D1 = distance_from(dedge[0])
+D2 = distance_from(dedge[1])
+
+len1 = len(D1) # number of nodes
+len2 = len(D2)
+sum1 = sum(D1.values()) # sum of contributions
+sum2 = sum(D2.values())
+
+s = sum1*len2 + sum2*len1 + dedge[2]*len1*len2
+
+print(s)
+
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+
n = int(input())
+
+# Over-engineered query-style solution: roots the tree at node 0, then
+# memoizes path counts/sums below and above each node so any removed
+# edge can be answered. The order of the memoized calls matters.
+adj_list = [[] for _ in range(n)]
+par = [None]*n
+
+roads = []
+for i in range(n-1):
+ i, j, d = list(map(int, input().split()))
+ adj_list[i-1].append((j-1, d))
+ adj_list[j-1].append((i-1, d))
+ roads.append((i-1, j-1, d))
+
+def dfs(root):
+ # Fix the tree structure: record each node's parent relative to root 0.
+ for child, distance in adj_list[root]:
+ if child == par[root]: continue
+ par[child] = root
+ dfs(child)
+
+dfs(0)
+
+# Memo tables for the four recursions below.
+_num_paths_below = [None]*n
+_num_paths_above = [None]*n
+_sum_paths_below = [None]*n
+_sum_paths_above = [None]*n
+
+def num_paths_below(i):
+ # Number of nodes in the subtree rooted at i (including i).
+ if _num_paths_below[i] != None:
+ return _num_paths_below[i]
+ cur = 1
+ for j, distance in adj_list[i]:
+ if par[i] == j: continue
+ cur += num_paths_below(j)
+ _num_paths_below[i] = cur
+ return cur
+
+def sum_paths_below(i):
+ # Sum of distances from i to every node in its subtree.
+ if _sum_paths_below[i] != None:
+ return _sum_paths_below[i]
+ cur = 0
+ for j, distance in adj_list[i]:
+ if par[i] == j: continue
+ cur += sum_paths_below(j) + distance * num_paths_below(j)
+ _sum_paths_below[i] = cur
+ return cur
+
+def num_paths_above(i):
+ # Number of nodes NOT in i's subtree, plus i itself.
+ if _num_paths_above[i] != None:
+ return _num_paths_above[i]
+ if par[i] == None:
+ return 1
+ cur = num_paths_above(par[i]) + 1
+ for child, distance in adj_list[par[i]]:
+ if child == i: continue
+ if child == par[par[i]]: continue
+ cur += num_paths_below(child)
+ _num_paths_above[i] = cur
+ return cur
+
+def sum_paths_above(i):
+ # Sum of distances from i to every node outside its subtree.
+ if _sum_paths_above[i] != None:
+ return _sum_paths_above[i]
+ if par[i] == None:
+ return 0
+ # par_dist: length of the edge from i up to its parent.
+ for child, distance in adj_list[par[i]]:
+ if child == i: par_dist = distance
+ cur = sum_paths_above(par[i]) + par_dist * num_paths_above(par[i])
+ for child, distance in adj_list[par[i]]:
+ if child == i: continue
+ if child == par[par[i]]: continue
+ cur += sum_paths_below(child) + (par_dist + distance) * num_paths_below(child)
+ _sum_paths_above[i] = cur
+ return cur
+
+road_index = int(input())
+
+rx, ry, road_distance = roads[road_index - 1]
+# Orient the removed edge so ry is the child side.
+if par[rx] == ry: rx, ry = ry, rx
+# Now, par[ry] = rx.
+
+# Contribution sums and node counts on the rx side (everything except
+# ry's subtree) and the ry side (ry's subtree).
+sum_x_size = sum_paths_above(rx)
+for child, distance in adj_list[rx]:
+ if child == ry: continue
+ if child == par[rx]: continue
+ sum_x_size += sum_paths_below(child) + distance * num_paths_below(child)
+
+sum_y_size = sum_paths_below(ry)
+
+num_x_size = num_paths_above(rx)
+for child, distance in adj_list[rx]:
+ if child == ry: continue
+ if child == par[rx]: continue
+ num_x_size += num_paths_below(child)
+
+num_y_size = num_paths_below(ry)
+
+total_productivity_lost = sum_x_size * num_y_size + sum_y_size * num_x_size + road_distance * num_x_size * num_y_size
+print(total_productivity_lost)
+
Misc 3
Hint 1
This problem is kind of an either you get it or you don’t problem, so it's hard to give hints.
Your first course of action should be deciding on an encoding for all positions in the game. You’ll need to keep track of the current position, as well as what pieces of the di-force you’ve collected. This is important as it determines where you need to go, which affects the expectation.
Hint 2
Try writing out a recursive formula for the expectation of the game ending at any particular state in the game.
For the end state (boss position, all of the di-force collected), the expectation is $0$. For all positions next to the boss, the state with all of the di-force collected will look like $\mathbb{E}(P) = 1 + \frac{1}{4} \times 0 + \frac{1}{4} \mathbb{E}(X) + \frac{1}{4} \mathbb{E}(Y) + \frac{1}{4} \mathbb{E}(Z)$, where $X, Y$ and $Z$ are possible positions one could move to (They could also be $P$!)
Solution
Take all possible states of the game, and we’ll make a recursive formula for the expected number of steps to end the game from that position.
If we do this, we’ll end up with a system of equations, with $N$ equations and $N$ unknowns. You can take it for granted that this board produces an actual expected value for all valid locations. So we can solve this using gaussian elimination to figure out the solution to all variables simultaneously!
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+
import sys
+
+# Expected-steps random walk: encode (x, y, collected-objectives bitmask)
+# as one variable each, build the linear system of expectations, and
+# solve it by Gauss-Jordan elimination.
+n, m = list(map(int, input().split()))
+
+lines = [list(input()) for _ in range(n)]
+
+n_objectives = 2
+
+# Replace objective cells with bitmask values; remember spawn and goal.
+for x in range(n):
+ for y in range(m):
+ if lines[x][y] == "S":
+ spawn = (x, y)
+ elif lines[x][y] == "A":
+ lines[x][y] = 1
+ elif lines[x][y] == "B":
+ lines[x][y] = 1<<1
+ elif lines[x][y] == "G":
+ goal = (x, y)
+
+def node_encode(x, y, cur_state):
+ # Map (position, bitmask) to a unique variable index.
+ return m*x + y + cur_state * n * m
+
+def node_decode(pos):
+ # Inverse of node_encode; the final column is the constant term.
+ cur_state = pos // (n * m)
+ if cur_state == 1<<n_objectives:
+ return "constant"
+ pos %= n * m
+ x = pos // m
+ pos %= m
+ y = pos
+ return x, y, cur_state
+
+# Augmented matrix: one row per state, last column is the constant.
+matrix = [
+ [
+ 0 for _ in range(n*m*(1<<n_objectives) + 1)
+ ] for _ in range(n*m*(1<<n_objectives))
+]
+
+for x in range(n):
+ for y in range(m):
+ if lines[x][y] == "X":
+ continue
+ for cur_state in range(1<<n_objectives):
+ # From this, we can move U/D/L/R
+ this_node = node_encode(x, y, cur_state)
+ if (x, y) == goal and cur_state == (1<<n_objectives)-1:
+ # We are done!
+ matrix[this_node][this_node] = 1
+ continue
+ options = []
+ for a, b in [
+ (x+1, y),
+ (x-1, y),
+ (x, y+1),
+ (x, y-1),
+ ]:
+ new_state = cur_state
+ # Moves off-grid or into walls stay in place.
+ if not (0 <= a < n and 0 <= b < m):
+ a, b = x, y
+ if lines[a][b] == "X":
+ a, b = x, y
+ if type(lines[a][b]) == int:
+ new_state |= lines[a][b]
+ options.append(node_encode(a, b, new_state))
+ # negate so the postive values are equal to this.
+ matrix[this_node][this_node] = -1
+ # 1 more step
+ matrix[this_node][-1] = 1
+ for option in options:
+ matrix[this_node][option] += 1/4
+
+def reduced_row_echelon_form(matrix):
+ # In-place Gauss-Jordan elimination (standard RREF algorithm).
+ rowCount = len(matrix)
+ colCount = len(matrix[0])
+ lead = 0
+ for r in range(rowCount):
+ if colCount <= lead: return
+ i = r
+ while matrix[i][lead] == 0:
+ i += 1
+ if rowCount == i:
+ i = r
+ lead = lead + 1
+ if colCount == lead:
+ return
+ # swap rows i and r
+ matrix[i], matrix[r] = matrix[r], matrix[i]
+ if matrix[r][lead] != 0:
+ div = matrix[r][lead]
+ for c in range(colCount):
+ matrix[r][c] /= div
+ for i in range(rowCount):
+ if i != r:
+ deduction = matrix[i][lead]
+ for c in range(colCount):
+ matrix[i][c] -= deduction * matrix[r][c]
+ lead += 1
+
+reduced_row_echelon_form(matrix)
+
+start_node = node_encode(*spawn, 0)
+# Find the row with col value at start position equal to 1.
+for r in range(len(matrix)):
+ if matrix[r][start_node] != 0:
+ print("non-zero row", file=sys.stderr)
+ for idx, val in enumerate(matrix[r]):
+ if val != 0:
+ print(node_decode(idx), val, file=sys.stderr)
+
+ print(-matrix[r][-1])
+
Recursion 0
Hint 1
This problem is purely about implementation - there’s no tricks, you just need to simulate the sequence.
Make sure you are performing the MOD operation!
Hint 2
If you’re struggling with the implementation - search up a solution which computes the fibonacci numbers, and try to translate it to this sequence.
Solution
As the hints say, this is purely an implementation problem. Because our recurrence looks back two steps, we need two temporary variables to store the current values in the sequence.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+
MOD = int(1e9+7)
+
+# Iterate the recurrence g(i+2) = 3*g(i+1) + g(i) mod MOD, keeping a
+# sliding window of the two latest values.
+g1 = 1 # After i iterations, g1 = g(i+1)
+g0 = 0 # After i iterations, g0 = g(i)
+
+n = int(input())
+for _ in range(n):
+ # set j = i + 1
+ # g(j+1) = 3*g(i+1) + g(i)
+ # g(j) = g(i+1)
+ g1, g0 = (3*g1 + g0) % MOD, g1
+
+print(g0)
+
Recursion 1
Hint 1
This is a rather famous problem - there is a well known formula for $F(2n)$ and $F(2n+1)$
Hint 2
If we repeatedly apply the rules for $F(2n)$ and $F(2n+1)$, We’ll have about $\log_2(n)$ unique values of $F$ we’ll need to compute. If we don’t cache our results though, you’ll run into issues.
Solution
Note: for this problem we have F(0) = 1, but usually, F(0) = 0. So for this solution, assume F(0) = 0, then we can simply output F(n+1) in our solution rather than F(n).
As the hints hint at, there is a well known formula for $F(2n)$ and $F(2n+1)$, rather than giving it to you, let’s prove it!
This proof makes use of a very unique way to generate fibonacci numbers - with matrices!
Notice that the following matrix:
\[M = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}\]Paired with the following initial vector:
\[V = \begin{bmatrix} 1\\ 0 \end{bmatrix}\]Can be used to generate fibonacci numbers! Let’s find out what happens when we repeatedly multiply $V$ by $M$
\[MV = \begin{bmatrix} 1\\ 1 \end{bmatrix}. M^2V = \begin{bmatrix} 2\\ 1 \end{bmatrix}. M^3V = \begin{bmatrix} 3\\ 2 \end{bmatrix}. M^4V = \begin{bmatrix} 5\\ 3 \end{bmatrix}\]Spotting the pattern? In general, if we have
\[V = \begin{bmatrix} F(n)\\ F(n-1) \end{bmatrix}. MV = \begin{bmatrix} F(n) + F(n-1)\\ F(n) \end{bmatrix} = \begin{bmatrix} F(n+1)\\ F(n) \end{bmatrix}\]The matrix $M$ moves each of the values in the vector along one step in the fibonacci sequence! This is because the top row of $M$ adds the two vector values together, and the bottom row of $M$ just preserves the top value of the vector.
Repeatedly multiplying matrix $M$ you’ll find that
\[M^n = \begin{bmatrix} F(n+1) & F(n)\\ F(n) & F(n-1) \end{bmatrix}\]And so if
\[M^{2n}\begin{bmatrix} 1\\ 0 \end{bmatrix} = \begin{bmatrix} F(2n+1)\\ F(2n) \end{bmatrix}\]Then since
\[M^{2n} = M^n \times M^n = \begin{bmatrix} F(n+1) & F(n)\\ F(n) & F(n-1) \end{bmatrix} \times \begin{bmatrix} F(n+1) & F(n)\\ F(n) & F(n-1) \end{bmatrix} = \begin{bmatrix} F(n+1)^2 + F(n)^2 & F(n+1)F(n) + F(n)F(n-1)\\ F(n+1)F(n) F(n)F(n-1) & F(n)^2 + F(n-1)^2 \end{bmatrix}\]we have (after multiplying this by $V$):
\[F(2n+1) = F(n+1)^2 + F(n)^2, F(2n) = F(n+1)F(n) + F(n)F(n-1)\]Using this rule, we can solve the problem in logarithmic time, provided you use dynamic programming.
That being said, since we have a matrix that generates the fibonacci sequence, we can just use fast matrix exponentation to find our result too! No need for the fancy formula.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+
from functools import cache

MOD = int(1e9+7)

@cache
def fibonacci(n):
    """Return F(n) mod 1e9+7 via the fast-doubling identities.

    Uses F(2m+1) = F(m+1)^2 + F(m)^2 and
    F(2m)   = F(m+1)F(m) + F(m)F(m-1), memoised with functools.cache.
    """
    if n <= 3:
        # Base cases: F(0)=0, F(1)=1, F(2)=1, F(3)=2.
        return (0, 1, 1, 2)[n]
    half = n // 2
    lo, mid, hi = fibonacci(half - 1), fibonacci(half), fibonacci(half + 1)
    if n % 2:
        return (hi * hi + mid * mid) % MOD
    return (hi * mid + mid * lo) % MOD
+
# Read n from stdin; this problem's sequence is the Fibonacci sequence
# shifted by one, hence the +1 before the lookup.
print(fibonacci(int(input()) + 1))
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+
import math

MOD = int(1e9+7)

# Fibonacci step matrix: multiplying the column vector [F(n), F(n-1)]^T
# by this matrix yields [F(n+1), F(n)]^T.
matrix = [
    [1, 1],
    [1, 0],
]

def mat_mult(m1, m2):
    """Return the matrix product m1 @ m2 with every entry reduced mod MOD.

    m1 is (r x k), m2 is (k x c); the result is (r x c).
    """
    result = [[0 for _ in range(len(m2[0]))] for _ in range(len(m1))]
    for x in range(len(m1)):
        for y in range(len(m2[0])):
            res = 0
            for a in range(len(m2)):
                # Reduce as we go so intermediate values stay small.
                res = (res + m1[x][a] * m2[a][y]) % MOD
            result[x][y] = res
    return result

def exponentiate(mat, p):
    """Return mat**p (entries mod MOD) by binary exponentiation.

    Fixed: the original computed math.log2(p) to size the bit loop,
    which raised ValueError for p == 0; this version returns the
    identity matrix for p == 0 and needs no log at all.
    """
    # Start from the identity matrix of matching dimensions.
    cur_val = [[int(i1 == i2) for i2 in range(len(mat[0]))] for i1 in range(len(mat))]
    cur_power = mat
    while p:
        if p & 1:
            cur_val = mat_mult(cur_val, cur_power)
        p >>= 1
        if p:
            cur_power = mat_mult(cur_power, cur_power)
    return cur_val
+
# F(n+1) is the required answer; read it off as the bottom entry of
# M^(n+1) applied to the start vector [1, 0]^T.
n = int(input()) + 1
power = exponentiate(matrix, n)
final = mat_mult(power, [[1], [0]])
print(int(final[1][0]))
+
Recursion 2
Hint 1
If you haven’t already, give the solution to Recursion 1 a look, even if you solved the problem. Some of the tools there might be useful.
Hint 2
You need to use the matrix solution from recursion 1, however our new matrix needs to compute $G$ rather than $F$. How can we do this?
Solution
As the hints say, we need to come up with a matrix $M$ that when multiplied by a vector $V$, moves it through the $G$ sequence.
The trick here is to notice that while a $2 \times 2$ matrix makes this impossible, we can achieve this if we have a $4\times 4$ matrix, by keeping two rows for tracking the fibonacci sequence and two rows for tracking $G$:
\[M = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 1 & 0 & 2 & 3 \\ 0 & 0 & 1 & 0 \end{bmatrix}\]Notice how the top left corner is the same matrix that generates $F$. So if the first two rows of $V$ are the same, then the first two rows of $V$ continue to generate the fibonacci numbers.
The bottom row of $M$ does the same thing as the second row did in the original - It keeps the bottom row one iteration before the third row.
The only remaining row - the third one - does the actual calculation. $G(n) = 2 \times G(n-1) + 3 \times G(n-2) + F(n-1)$. Notice that since $G(n)$ is evaluated in the matrix the same time as $F(n)$, adding the top row from the previous iteration constitutes $F(n-1)$.
Multiplying this with \(V = \begin{bmatrix}F(1)\\ F(0)\\ G(1)\\ G(0)\end{bmatrix}\) advances the vector into \(V = \begin{bmatrix}F(2)\\ F(1)\\ G(2)\\ G(1)\end{bmatrix}\)
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+
import math

MOD = int(1e9+7)

# 4x4 step matrix: the top-left 2x2 block advances the Fibonacci pair,
# rows 3-4 advance the G pair via G(n) = 2*G(n-1) + 3*G(n-2) + F(n-1).
matrix = [
    [1, 1, 0, 0],
    [1, 0, 0, 0],
    [1, 0, 2, 3],
    [0, 0, 1, 0],
]

def mat_mult(m1, m2):
    """Return the matrix product m1 @ m2 with every entry reduced mod MOD.

    m1 is (r x k), m2 is (k x c); the result is (r x c).
    """
    result = [[0 for _ in range(len(m2[0]))] for _ in range(len(m1))]
    for x in range(len(m1)):
        for y in range(len(m2[0])):
            res = 0
            for a in range(len(m2)):
                # Reduce as we go so intermediate values stay small.
                res = (res + m1[x][a] * m2[a][y]) % MOD
            result[x][y] = res
    return result

def exponentiate(mat, p):
    """Return mat**p (entries mod MOD) by binary exponentiation.

    Fixed: the original sized its bit loop with math.log2(p), which
    raised ValueError for p == 0; this version returns the identity
    matrix for p == 0 and needs no log at all.
    """
    # Start from the identity matrix of matching dimensions.
    cur_val = [[int(i1 == i2) for i2 in range(len(mat[0]))] for i1 in range(len(mat))]
    cur_power = mat
    while p:
        if p & 1:
            cur_val = mat_mult(cur_val, cur_power)
        p >>= 1
        if p:
            cur_power = mat_mult(cur_power, cur_power)
    return cur_val
+
n = int(input())
if n == 0:
    # The answer for n = 0 is 0 outright; no matrix work needed.
    print(0)
else:
    # Seed vector; applying M repeatedly gives (2, 1, 11, 5), (3, 2, 39, 11), ...
    start = [[1], [1], [5], [0]]
    final = mat_mult(exponentiate(matrix, n), start)
    print(int(final[3][0]))
+
Maybe there’s a special kind of rule here too, but I haven’t gone looking for one.
Recursion 3
Hint 1
If you’ve solved recursion 2, this is just a harder version of that. Try to come up with a matrix that computes the sequence. You may need multiple intermediate sequences.
Hint 2
The particular sequence $B(n) = B(n-1) + A(n), B(0) = A(0)$ may be useful. What is $B(n)$ in closed form?
What about the sequence $C(n) = C(n-1) + B(n), C(0) = B(0)$?
Solution
Note: I’m again going to assume $F(0)=0$, then we can simply translate $H(n) = 4\times H(n-1) + \sum_{i=1}^{n-1} ((n-i)^2\times F(i-1))$ into $H(n) = 4\times H(n-1) + \sum_{i=0}^{n-1} ((n-i)^2\times F(i))$, since $n \times 0 = 0$.
Let’s first tackle some of the questions in Hint 2, on our quest to find a matrix.
$B(n) = \sum_{i=0}^{n} A(i)$, and $C(n) = \sum_{i=0}^n B(i) = \sum_{i=0}^n \sum_{j=0}^i A(j) = \sum_{i=0}^n (n + 1 -i)A(i)$
Interesting… What about $D(n) = D(n-1) + C(n)$?
This would have closed form $D(n) = \sum_{i=0}^n C(i) = \sum_{i=0}^n \sum_{j=0}^i (i-j+1)A(j) = \sum_{i=0}^n \frac{(n-i+1)(n-i+2)}{2} A(i) = \sum_{i=0}^n \frac{1}{2}((n-i)^2 + 3(n-i) + 2) A(i)$.
Very interesting…
What if we took $2D(n) - 3C(n) + B(n)$? That would give us $\sum_{i=0}^n ((n-i)^2 + 3(n-i) + 2 - 3(n-i) - 3 + 1) A(i) = \sum_{i=0}^n (n-i)^2 A(i)$. That’s pretty much it! Just the top boundary of the sum is wrong.
What if we instead defined $B(n) = A(n-1) + B(n-1), B(0) = 0$? Then we’d have $B(n) = \sum_{i=0}^{n-1} A(i)$.
Keeping the definition of $C$, we’d have $C(n) = \sum_{i=0}^n B(i) = \sum_{i=0}^n \sum_{j=0}^{i-1} A(j) = \sum_{i=0}^{n-1} (n-i)A(i)$.
Keeping the definition of $D$, we’d have $D(n) = \sum_{i=0}^n C(i) = \sum_{i=0}^n \sum_{j=0}^{i-1} (n-j)A(j) = \sum_{i=0}^{n-1} \frac{(n-i)(n-i+1)}{2}A(i) = \sum_{i=0}^{n-1} \frac{1}{2}((n-i)^2 + (n-i))A(i)$
Using this new definition, what is $2D(n) - C(n)$?
\[\sum_{i=0}^{n-1} ((n-i)^2 + (n-i) - (n-i))A(i) = \sum_{i=0}^{n-1} (n-i)^2A(i)\]Bingo!
Using a similar strategy to before, we define our matrix:
\[\begin{bmatrix} 1 & 1 & 0 & 0 & 0 & 0 & 0\\ 1 & 0 & 0 & 0 & 0 & 0 & 0\\ 1 & 0 & 1 & 0 & 0 & 0 & 0\\ 0 & 0 & 1 & 1 & 0 & 0 & 0\\ 0 & 0 & 0 & 1 & 0 & 0 & 0\\ 0 & 0 & 0 & 1 & 0 & 1 & 0\\ 0 & 0 & 0 & 0 & -1 & 2 & 4\\ \end{bmatrix}\]Let’s analyse this row by row.
- Rows 1 and 2 compute $F(n+2)$ and $F(n+1)$ respectively. You’ll see the reason for the offsets later.
- Row 3 computes $B(n+2)$ (From now on, assume we’ve set $A = F$). This adds $F(n+2)$ to $B(n+2)$.
- Row 4 computes $C(n+1)$. This adds $B(n+2)$ to $C(n+1)$.
- Row 5 computes $C(n)$. This will be needed for later.
- Row 6 computes $D(n)$. This adds $C(n+1)$ to $D(n)$.
- Row 7 computes $H(n)$. This adds $4H(n)$ to $2D(n)$ and subtracts $C(n)$ from the result.
With all this in place, we need only define our vector $V$, to contain $F(2), F(1), B(2), C(1), C(0), D(0), H(0)$.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+
import math

MOD = int(1e9+7)

# 7x7 step matrix: each row advances one of the intermediate sequences
# described in the write-up.
matrix = [
    [1, 1, 0, 0, 0, 0, 0],  # F(n+2)
    [1, 0, 0, 0, 0, 0, 0],  # F(n+1)
    [1, 0, 1, 0, 0, 0, 0],  # B(n+2)
    [0, 0, 1, 1, 0, 0, 0],  # C(n+1)
    [0, 0, 0, 1, 0, 0, 0],  # C(n)
    [0, 0, 0, 1, 0, 1, 0],  # D(n)
    [0, 0, 0, 0, -1, 2, 4], # H(n)
]

def mat_mult(m1, m2):
    """Product of m1 (r x k) and m2 (k x c), entries reduced mod MOD."""
    rows, cols, inner = len(m1), len(m2[0]), len(m2)
    out = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            acc = 0
            for t in range(inner):
                # Reduce as we go so intermediate values stay small.
                acc = (acc + m1[i][t] * m2[t][j]) % MOD
            out[i][j] = acc
    return out

def exponentiate(mat, p):
    """Return mat**p (entries mod MOD) by binary exponentiation."""
    size = len(mat)
    # Identity matrix of matching size.
    result = [[1 if r == c else 0 for c in range(len(mat[0]))] for r in range(size)]
    square = mat
    # log2(p+1) keeps the loop bound defined even when p == 0.
    for shift in range(2 + math.floor(math.log2(p + 1))):
        if (1 << shift) & p:
            result = mat_mult(result, square)
        square = mat_mult(square, square)
    return result
+
n = int(input())
transform = exponentiate(matrix, n)

# Start vector: F(2), F(1), B(2), C(1), C(0), D(0), H(0).
start = [[2], [1], [2], [1], [0], [0], [0]]

result = mat_mult(transform, start)
print(int(result[6][0]))

# Spot checks against the recurrence (H = 0, 1, 9, 51, ...):
# 4*0 + 1 * F(1) = 1 OR 2*1-1
# 4*1 + 4 * F(1) + 1 * F(2) = 9
# 4*9 + 9 * F(1) + 4 * F(2) + 1 * F(3) = 36 + 9 + 4 + 2 = 51
+
Note that we can generate matrices using this trick for any upper bound on the sum, not just $n-1$, and any polynomial of $n-i$, not just $(n-i)^2$. I’d like to encourage you to think of other cool recursive sequences we can model via matrix multiplication.
This post is licensed under GNU GPL V3 by the author. Comments powered by Disqus.
diff --git a/posts/factorization/index.html b/posts/factorization/index.html
new file mode 100644
index 0000000..cdd671c
--- /dev/null
+++ b/posts/factorization/index.html
@@ -0,0 +1,261 @@
+ Primes and Factorization Techniques | Monash Code Binder Primes and Factorization Techniques
Why?
Many number theoretic problems in competitive programming require analysing the factors or prime factors of a number. Here I’ll list a few techniques for finding these factors, and some techniques / properties involving the factors / prime factors of a number.
Preliminaries
First off, lets define our basic terms, then we can get into the interesting stuff.
Definitions
For a positive integer \(x\), a positive integer \(y\) is a factor of \(x\) if \(x\) can be expressed as \(yk\) for some other positive integer \(k\). A factor \(y\) is a prime factor if \(y\) is prime.
A factorization of a positive integer \(x\) is a collection of (not necessarily distinct) factors \(y_i\) satisfying \(y_0 \times y_1 \times \cdots \times y_k = x\). A prime factorization is a factorization where all \(y_i\) are prime factors.
All the below properties depend on the fact that 1 is not a prime number. I repeat: in all algorithms from now on (and I think you should adopt this generally too) we do not consider 1 to be a prime number!
Basic Properties
Every positive integer has a single unique prime factorization. I’ll leave it as an exercise for you to prove this is true with induction.
Let’s suppose \(x\) can be written in the form
\[x = y_0^{a_0} \times y_1^{a_1} \times y_2^{a_2} \times \cdots ,\]where $y_i$ are distinct primes. Then there are \((a_0+1)\times(a_1+1)\times(a_2+1)\times\cdots\) factors of $x$, since any factor can use anywhere from \(0\) to \(a_i\) copies of prime \(p_i\).
Given the prime factorization of two numbers \(x\) and \(y\), we can find the gcd simply by considering the intersection of their factorizations. The lcm of those two numbers is the union of their factorizations.
Finding Factors / Primes
Sieve of Eratosthenes
A simple (but powerful) approach is that of Eratosthenes: Go through each prime, and mark all multiples of that prime as not prime.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+
def sieve(n):
    """Sieve of Eratosthenes: return a list where entry k is True iff k
    is prime, for 0 <= k <= n."""
    isprime = [True] * (n + 1)
    isprime[0] = isprime[1] = False
    i = 2
    # Only primes up to sqrt(n) need to cross anything off.
    while i * i <= n:
        if isprime[i]:
            # Start at i*i: smaller multiples of i were already struck
            # by a smaller prime factor.
            for j in range(i * i, n + 1, i):
                isprime[j] = False
        i += 1
    return isprime
+
1
+2
+3
+4
+5
+6
+7
+8
+
// isprime[k] == true iff k is prime (valid after calling sieve(n)).
vector<bool> isprime;

// Sieve of Eratosthenes over [0, n]. Crossing-off starts at i*i since
// any smaller multiple of i was already struck by a smaller prime.
void sieve(int n) {
    isprime.assign(n+1, true);
    isprime[0] = isprime[1] = false;
    for (int i = 2; i * i <= n; i++) {
        if (!isprime[i]) continue;
        for (int j = i * i; j <= n; j += i) isprime[j] = false;
    }
}
+
There is some optimization going on here. We only want to mark off multiples of \(i\) in the isprime
vector where \(i\) is the smallest prime factor of that multiple. This is the reason we can end when \(i \times i > n\) and start at \(j = i \times i\).
This operation is contest safe until about \(10^8\).
Factors with Sieve
While finding whether a number is prime is all well and good, rather than just recording primality, why not also record what prime factor caused us not to be prime? From this information we can then generate prime factors of a number, if we wish.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+
def factor_sieve(n):
    """Linear sieve: return (factor, primes).

    factor[i] holds the smallest prime factor of i (factor[0] and
    factor[1] stay 0); primes lists every prime <= n in order.
    """
    factor = [0] * (n + 1)
    primes = []
    for i in range(2, n + 1):
        if not factor[i]:
            # No smaller factor was recorded, so i is prime.
            factor[i] = i
            primes.append(i)
        smallest = factor[i]
        # Strike i*p only for primes p <= smallest factor of i, so each
        # composite is marked exactly once, by its smallest prime.
        for p in primes:
            composite = i * p
            if p > smallest or composite > n:
                break
            factor[composite] = p
    return factor, primes
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+
// factor[i] = smallest prime factor of i (0 for i < 2);
// primes = every prime <= n, in increasing order.
vi factor, primes;

// Linear sieve: each composite is struck exactly once, via its
// smallest prime factor.
void factor_sieve(int n) {
    // factor[i] = a factor of i. factor[1] = 0.
    factor.assign(n+1, 0);
    for (int i = 2; i <= n; i++) {
        if (!factor[i]) {
            factor[i] = i;
            primes.push_back(i);
        }
        // Make sure that the smallest factor is always listed.
        for (int p : primes) {
            if (p > factor[i] || i * p > n) break;
            factor[i * p] = p;
        }
    }
}
+
This operation is contest safe until about \(10^7\).
Prime Factorizations with precomp
If we want to get a full list of prime factors given we’ve done the above precomputation, we can just repeatedly divide i
by factor[i]
until we hit 1.
1
+2
+3
+4
+5
+6
+7
+8
+9
+
def fast_factors(n, factor):
    """Return the distinct prime factors of n using a precomputed
    smallest-factor table (as produced by factor_sieve)."""
    res = []
    while n > 1:
        f = factor[n]
        res.append(f)
        # Strip every copy of f; keep only a single division here to
        # collect duplicate factors instead.
        while n % f == 0:
            n //= f
    return res
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+
// Distinct prime factors of n, read off the precomputed
// smallest-factor table from factor_sieve.
vi fast_factors(int n) {
    vi res;
    while (n > 1) {
        int f = factor[n];
        res.push_back(f);
        // Strip every copy of f; keep a single division to collect duplicates.
        while (n % f == 0) n /= f;
    }
    return res;
}
+
Prime factors without precomp
If you are only testing a few numbers for prime factors, you can instead do this in \(O(\sqrt{n})\):
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+
def slow_factors(n):
    """Trial division in O(sqrt(n)): return the distinct prime factors of n."""
    res = []
    d = 2
    while d * d <= n:
        if n % d == 0:
            res.append(d)
            # Strip every copy of d (remove this loop, keeping a single
            # division, to collect duplicate factors).
            while n % d == 0:
                n //= d
        d += 1
    if n > 1:
        # Whatever survives the divisions is a single prime > sqrt(n).
        res.append(n)
    return res
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+
// Trial division in O(sqrt(n)): distinct prime factors of n.
vll slow_factors(ll n) {
    vll res;
    for (ll d = 2; d * d <= n; d++) {
        if (n % d) continue;
        res.push_back(d);
        while (n % d == 0) n /= d; // strip all copies (remove to keep duplicates)
    }
    if (n > 1) res.push_back(n); // leftover prime > sqrt(original n)
    return res;
}
+
Fast Prime Testing
The first three algorithms are good when you need to know the primality / factors of numbers in a range. If you just want to test a single number, we can do much better though. As we’ve seen we can get the prime factorization in \(O(\sqrt{n})\), but for just primality testing, we can do \(O(\log(n))\) with Miller-Rabin.
I won’t go into details how this works, but we provide Miller-Rabin with some bases depending on how big our number is:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+
# Deterministic Miller-Rabin witness sets; the last assignment wins, so
# keep whichever line matches your input bound.
val = [2, 7, 61] # for n <= 2^32
val = [2, 13, 23, 1662803] # for n <= 10^12
val = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37] # for n <= 2^64

def isprime(n):
    """Deterministic Miller-Rabin primality test for n within the bound
    covered by the chosen witness set `val`."""
    if n < 2: return False
    # Write n - 1 = d * 2^s with d odd.
    d = n - 1
    s = 0
    while d & 1 == 0:
        d >>= 1
        s += 1
    for v in val:
        if v >= n: break
        # v^d mod n -- the built-in three-argument pow does modular
        # exponentiation, so no external expmod helper is needed.
        x = pow(v, d, n)
        if (x == 1 or x == n-1): continue
        # Repeated squaring, looking for -1 (mod n); if it never shows
        # up, v witnesses that n is composite.
        for r in range(s):
            x = pow(x, 2, n)
            if (x == n - 1): break
        else:
            return False
    return True
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+
vi val = {2, 7, 61}; // n <= 2^32
+vi val = {2, 13, 23, 1662803}; // n <= 10^12
+vi val = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}; // n <= 2^64 (replace ll with __int128)
+
+bool isprime(ll n) {
+ if (n < 2) return false;
+ ll s = __builtin_ctzll(n-1), d = (n - 1) >> s;
+ for (int v: val) {
+ if (v >= n) break;
+ ll x = expmod(v, d, n);
+ if (x == 1 || x == n-1) continue;
+ // big is typedef'd to __int128
+ for (ll r=1; r<s; r++) if ((x = ((big(x)*x) % n)) == n-1) goto nextPrime;
+ return false;
+ nextPrime:;
+ }
+ return true;
+}
+
You can learn more about Deterministic Miller-Rabin here.
Cool properties
Since factors and primes can be thought of as the building blocks of all the integers and to some extent are still shrouded in mystery, there are lots of cool properties of primes that we might want to utilise in programming problems.
Squares
The prime factorization of squares has \(a_i\) even for all \(i\), where
\[x = y_0^{a_0}\times y_1^{a_1} \times y_2^{a_2} \times \cdots.\]Another way of thinking about this is that every square number can be factorized into prime squares.
In general for powers of 3, 4, etc. Simply replace “\(a_i\) even” with “\(a_i\) divisible by 3, 4, etc”.
Sums
Suppose we add two coprime numbers \(A\) and \(B\) (Numbers with gcd\((A, B) = 1\)). We know their sum must then be coprime to both \(A\) and \(B\) (\(\text{gcd}(A+B, A) = \text{gcd}(A+B, B) = 1\)). In other words, the prime factorization of \(A+B\) comprises primes which are neither in \(A\) nor \(B\).
Factorization Wheels
While going through with the Sieve method is good for finding the composite / prime status of every number, if you just want to list off primes it might be better to combine with a wheel factorization method, allowing you to strike off many candidates based on the first few factors.
You can find out about wheel factorization here.
Sieve starting at different positions
As we’ve seen previously, the sieve of Eratosthenes only needs to consider primes less than \(\sqrt{n}\) for the algorithm to work (This can be reasoned by noting that any composite number requires a factor less than or equal to \(\sqrt{n}\)).
Therefore, we can modify our solution slightly to compute the prime numbers in some other range. For example, to compute the primes between \(10^{12} - 10^7\) and \(10^{12}\):
- Use the sieve to find all primes less than \(\sqrt{10^{12}} = 10^6\).
- Sieve over these primes in the range \(10^{12} - 10^7\) to \(10^{12}\). Anything not marked as composite must be prime!
We can actually continue this argument with contiguous segments to get a \(\sqrt{n}\) space complexity sieve implementation. For details see here
Related Problems
This post is licensed under GNU GPL V3 by the author. Comments powered by Disqus.
diff --git a/posts/index.html b/posts/index.html
new file mode 100644
index 0000000..c31a99c
--- /dev/null
+++ b/posts/index.html
@@ -0,0 +1,11 @@
+
+
+
+ Redirecting…
+
+
+
+
+ Redirecting…
+ Click here if you are not redirected.
+
diff --git a/posts/lca/index.html b/posts/lca/index.html
new file mode 100644
index 0000000..93dd78d
--- /dev/null
+++ b/posts/lca/index.html
@@ -0,0 +1,1083 @@
+ Least Common Ancestor (LCA) | Monash Code Binder Least Common Ancestor (LCA)
Where is this useful?
The Least Common Ancestor (LCA) data structure is useful wherever you have a directed graph where every vertex has out-degree \(\leq 1\). In more common terms, each vertex has a unique determined ‘parent’, or it is a root node, with no parent. The most common (and almost always only) example being a rooted tree.
On these particular graphs, the LCA gives us a fast way to move ‘up’ the graph (Towards your parents). In particular, we can use this to find the least common ancestor in \(\log (N)\) time, where the data structure gets its name from.
Reusing the analogy of parenting vertices, a vertex \(u\) is an ancestor of \(v\) if \(u\) is \(v\)’s parent, or \(v\)’s parent’s parent, and so on. As long as there is a line of ‘parentage’ connecting \(v\) to \(u\), \(u\) is an ancestor of \(v\). We consider \(v\) to also be its own ancestor.
The least common ancestor problem then requires, given two vertices \(x\) and \(y\), to find a vertex \(z\) in the graph such that \(z\) is an ancestor of both \(x\) and \(y\), but there is no vertex \(z’ \neq z\) such that \(z\) is an ancestor of \(z’\) and \(z’\) is an ancestor of \(x\) and \(y\) (In the tree example, we just want to find the lowest depth vertex whose subtree contains both \(x\) and \(y\)).
Note that the least common ancestor can be \(x\) or \(y\), if \(x\) is an ancestor of \(y\) or vice-versa.
Implementing the Data Structure
Interface
Let’s start by defining an interface for this data structure, and then slowly implement our methods.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+
class LCA:
    """Lowest-common-ancestor structure skeleton.

    Vertices are represented as numbers 0..n-1.  (The original listing
    had this docstring de-indented to column 0, which is a syntax error.)
    """

    def __init__(self, n_vertices):
        self.n = n_vertices
        # adjacency list: adjacent[u] holds (neighbour, edge weight) pairs
        self.adjacent = [[] for _ in range(self.n)]

    def add_edge(self, u, v, weight=1):
        # Undirected edge, stored in both directions.
        self.adjacent[u].append((v, weight))
        self.adjacent[v].append((u, weight))

    def build(self, root):
        # Once edges are added, build the tree/data structure.
        pass  # TODO

    def query(self, u, v, root=None):
        # What is the lowest common ancestor of u, v?
        # Extension: make this query from any root vertex you want.
        pass  # TODO

    def dist(self, u, v):
        # Find the distance between two vertices - very simple if we have LCA.
        pass  # TODO
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+
template<typename T = int> struct LCA {
    // vertices are represented as numbers 0->n-1.
    int n; vector<vector<pair<int, T> > > adjacent;

    LCA(int n_vertices) : n(n_vertices), adjacent(n) { }

    // Undirected edge, stored in both directions.
    void add_edge(int u, int v, T weight=1) {
        adjacent[u].emplace_back(v, weight);
        adjacent[v].emplace_back(u, weight);
    }

    void build(int root=0) {
        // Once edges are added, build the tree/data structure.
        // TODO
    }

    int query(int u, int v, int root=-1) {
        // What is the lowest common ancestor of u, v?
        // Extension: Make this query from any root vertex you want.
        return -1; // TODO: placeholder so the stub has defined behaviour
    }

    T dist(int u, int v) {
        // Find the distance between two vertices - very simple if we have LCA.
        return T(); // TODO: placeholder so the stub has defined behaviour
    }
};  // NOTE: the original listing was missing this semicolon
+
Useful data
First off, let’s save some intermediary data that will make our life a lot easier, and strictly define the tree structure. We’ll introduce three arrays: parent
, level
and length
.
parent
stores the direct parent of any vertex in the rooted tree.level
stores the level of the tree the vertex is at (Number of edges from it to the root)length
stores the length of the vertex to the root (Using edge weights).
We’ll populate these fields in the build
method, since all edges should be added by then.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+
class LCA:
+"""
+vertices are represented as numbers 0->n-1.
+"""
+
+ def __init__(self, n_vertices):
+ self.n = n_vertices
+ self.adjacent = [[] for _ in range(self.n)]
+
+ def add_edge(self, u, v, weight=1):
+ self.adjacent[u].append((v, weight))
+ self.adjacent[v].append((u, weight))
+
+ def dfs(self, source, c_parent, c_level, c_length):n
+ # Search from the source down the tree and set parent, level, length accordingly.n
+ self.parent[source] = c_parentn
+ self.level[source] = c_leveln
+ self.length[source] = c_lengthn
+ for child, weight in self.adjacent[source]:n
+ if child != c_parent:n
+ self.dfs(child, source, c_level + 1, c_length + weight)n
+
+ def build(self, root):
+ # Once edges are added, build the tree/data structure.
+ self.parent = [None]*self.nn
+ self.level = [None]*self.nn
+ self.length = [None]*self.nn
+ self.dfs(root, -1, 0, 0)n
+
+ def query(self, u, v, root=None):
+ # What is the lowest common ancestor of u, v?
+ # Extension: Make this query from any root vertex you want.
+ pass # TODO
+
+ def dist(self, u, v):
+ # Find the distance between two vertices - very simple if we have LCA.
+ pass # TODO
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+
template<typename T = int> struct LCA {
+ // vertices are represented as numbers 0->n-1.
+ int n; vector<vector<pair<int, T> > > adjacent;
+ vi parent, level;n
+ vector<T> length;n
+
+ LCA(int n_vertices) : n(n_vertices), adjacent(n), parent(n), level(n), length(n) { }m
+
+ void add_edge(int u, int v, T weight=1) {
+ adjacent[u].emplace_back(v, weight);
+ adjacent[v].emplace_back(u, weight);
+ }
+
+ void dfs(int source, int c_parent, int c_level, T c_length) {n
+ // Search from the source down the tree and set parent, level, length accordingly.n
+ parent[source] = c_parent;n
+ level[source] = c_level;n
+ length[source] = c_length;n
+ for (auto v: adjacent[source])n
+ if (v.first != c_parent)n
+ dfs(v.first, source, c_level+1, c_length+v.second);n
+ }
+
+ void build(int root=0) {
+ // Once edges are added, build the tree/data structure.
+ dfs(root, -1, 0, 0);n
+ }
+
+ int query(int u, int v, int root=-1) {
+ // What is the lowest common ancestor of u, v?
+ // Extension: Make this query from any root vertex you want.
+ // TODO
+ }
+
+ T dist(int u, int v) {
+ // Find the distance between two vertices - very simple if we have LCA.
+ // TODO
+ }
+}
+
So now we can query many useful characteristics of vertices in rooted trees. Now for the interesting part: let’s start creating data unique to the LCA structure.
Ancestor Array
LCA gets its fast queries by precomputing a special array, called ancestor
. Ancestor is a 2 dimensional array with ancestor[v][k]
storing the ancestor of vertex v
\(2^k\) edges towards the root. As an example, ancestor[v][0]
is parent[v]
(Parent is just ancestor 1 edge towards the root), and ancestor[v][1]
is parent[parent[v]]
where appropriate (2 edges towards root is same as parent’s parent).
If you just populated this array by searching up the tree \(2^k\) steps each time, you’d have worst case complexity \(O(n^2)\) to build the array. Luckily, we can use the fact that ancestor[v][k] = ancestor[ancestor[v][k-1]][k-1]
(In other words, you can move \(2^k\) steps towards the root by first moving \(2^{k-1}\) steps, which we’ve already computed, and then another \(2^{k-1}\) steps from this new position). This reduces the complexity to \(O(n\log_2(n))\)
We do this so that we can find the ancestor \(m\) edges towards the root for any arbitrary \(m\) in \(\log_2(m)\) time, while only using \(\log_2(n)\) space. We’ll see how this gets done later.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+
class LCA:
    """Lowest-common-ancestor structure with the binary-lifting table
    built; query and dist still TODO.

    Vertices are represented as numbers 0..n-1.  (The original listing
    was garbled by extraction: e.g. `MAX_LOG = 20n` and stray trailing
    'n' characters on several lines.)
    """

    # number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
    MAX_LOG = 20

    def __init__(self, n_vertices):
        self.n = n_vertices
        # adjacency list: adjacent[u] holds (neighbour, edge weight) pairs
        self.adjacent = [[] for _ in range(self.n)]

    def add_edge(self, u, v, weight=1):
        # Undirected edge, stored in both directions.
        self.adjacent[u].append((v, weight))
        self.adjacent[v].append((u, weight))

    def dfs(self, source, c_parent, c_level, c_length):
        # Search from the source down the tree and set parent, level,
        # length accordingly.
        self.parent[source] = c_parent
        self.level[source] = c_level
        self.length[source] = c_length
        for child, weight in self.adjacent[source]:
            if child != c_parent:
                self.dfs(child, source, c_level + 1, c_length + weight)

    def build(self, root):
        # Once edges are added, build the tree/data structure.
        self.parent = [None]*self.n
        self.level = [None]*self.n
        self.length = [None]*self.n
        self.dfs(root, -1, 0, 0)
        # ancestor[v][k] = vertex 2^k edges above v, or -1 past the root.
        self.ancestor = [[-1]*self.MAX_LOG for _ in range(self.n)]
        # Initial step: ancestor[v][0] = parent[v]
        for v in range(self.n):
            self.ancestor[v][0] = self.parent[v]
        # Now, compute ancestor[v][k] from 1->MAX_LOG
        for k in range(1, self.MAX_LOG):
            for v in range(self.n):
                if self.ancestor[v][k-1] != -1:
                    # Move 2^{k-1} up, then 2^{k-1} again.
                    self.ancestor[v][k] = self.ancestor[self.ancestor[v][k-1]][k-1]

    def query(self, u, v, root=None):
        # What is the lowest common ancestor of u, v?
        # Extension: make this query from any root vertex you want.
        pass  # TODO

    def dist(self, u, v):
        # Find the distance between two vertices - very simple if we have LCA.
        pass  # TODO
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+
template<typename T = int> struct LCA {
    // vertices are represented as numbers 0->n-1.
    // number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
    int MAX_LOG = 20;
    int n; vector<vector<pair<int, T> > > adjacent;
    vi parent, level;   // direct parent / depth (edge count) of each vertex
    vvi ancestor;       // ancestor[v][k] = vertex 2^k edges above v, -1 past root
    vector<T> length;   // weighted distance from the root

    LCA(int n_vertices) : n(n_vertices), adjacent(n), parent(n), level(n), length(n) { }

    // Undirected edge, stored in both directions.
    void add_edge(int u, int v, T weight=1) {
        adjacent[u].emplace_back(v, weight);
        adjacent[v].emplace_back(u, weight);
    }

    // Search from the source down the tree and set parent, level,
    // length accordingly.
    void dfs(int source, int c_parent, int c_level, T c_length) {
        parent[source] = c_parent;
        level[source] = c_level;
        length[source] = c_length;
        for (auto v: adjacent[source])
            if (v.first != c_parent)
                dfs(v.first, source, c_level+1, c_length+v.second);
    }

    void build(int root=0) {
        // Once edges are added, build the tree/data structure.
        dfs(root, -1, 0, 0);
        // Compute ancestor
        ancestor.assign(n, vi(MAX_LOG, -1));
        // Initial step: ancestor[v][0] = parent[v]
        for (int v=0; v<n; v++)
            ancestor[v][0] = parent[v];
        // Now, compute ancestor[v][k] from 1->MAX_LOG
        for (int k=1; k < MAX_LOG; k++)
            for (int v=0; v<n; v++)
                if (ancestor[v][k-1] != -1) {
                    // Move 2^{k-1} up, then 2^{k-1} again.
                    ancestor[v][k] = ancestor[ancestor[v][k-1]][k-1];
                }
    }

    int query(int u, int v, int root=-1) {
        // What is the lowest common ancestor of u, v?
        // Extension: Make this query from any root vertex you want.
        return -1; // TODO: placeholder so the stub has defined behaviour
    }

    T dist(int u, int v) {
        // Find the distance between two vertices - very simple if we have LCA.
        return T(); // TODO: placeholder so the stub has defined behaviour
    }
};  // NOTE: the original listing was missing this semicolon
+
Query
That’s actually most of the ingenuity out of the way, now we can get to implementing query
.
Provided we want the LCA with respect to the root we called build
from, we can define the LCA l
of u
and v
in the following way:
l
is the ancestor of u
and v
maximising level[l]
.
We also know that level[l] <= min(level[u], level[v])
. Using this, we can calculate query(u, v)
by:
- Finding the ancestors of
u
and v
(call them a1
, a2
) such that level[a1] = level[a2] = min(level[u], level[v])
. - Keep moving
a1
and a2
towards the root (higher and higher ancestors) until a1 == a2
. Then a1
and a2
are the LCA of u
and v
.
We can do both of these things on \(\log_2(n)\) time with this ancestor
array we’ve generated. Let’s see how:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+
class LCA:
+"""
+vertices are represented as numbers 0->n-1.
+"""
+
+ # number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
+ MAX_LOG = 20
+
+ def __init__(self, n_vertices):
+ self.n = n_vertices
+ self.adjacent = [[] for _ in range(self.n)]
+
+ def add_edge(self, u, v, weight=1):
+ self.adjacent[u].append((v, weight))
+ self.adjacent[v].append((u, weight))
+
+ def dfs(self, source, c_parent, c_level, c_length):
+ # Search from the source down the tree and set parent, level, length accordingly.
+ self.parent[source] = c_parent
+ self.level[source] = c_level
+ self.length[source] = c_length
+ for child, weight in self.adjacent[source]:
+ if child != c_parent:
+ self.dfs(child, source, c_level + 1, c_length + weight)
+
+ def build(self, root):
+ # Once edges are added, build the tree/data structure.
+ self.parent = [None]*self.n
+ self.level = [None]*self.n
+ self.length = [None]*self.n
+ self.dfs(root, -1, 0, 0)
+ self.ancestor = [[-1]*self.MAX_LOG for _ in range(self.n)]
+ # Initial step: ancestor[v][0] = parent[v]
+ for v in range(self.n):
+ self.ancestor[v][0] = self.parent[v]
+ # Now, compute ancestor[v][k] from 1->MAX_LOG
+ for k in range(1, self.MAX_LOG):
+ for v in range(self.n):
+ if self.ancestor[v][k-1] != -1:
+ # Move 2^{k-1} up, then 2^{k-1} again.
+ self.ancestor[v][k] = self.ancestor[self.ancestor[v][k-1]][k-1]
+
+ def query(self, u, v, root=None):
+ # What is the lowest common ancestor of u, v?
+ # Extension: Make this query from any root vertex you want.
+
+ if root is not None:
+ pass # TODO
+ # assume that u is higher up than v, to simplify the code below
+ if self.level[u] > self.level[v]:
+ u, v = v, u
+ # STEP 1: set u and v to be ancestors with the same level
+ for k in range(self.MAX_LOG-1, -1, -1):
+ if (self.level[v] - (1 << k) >= self.level[u]):
+ # If v is 2^k levels below u, move it up 2^k levels.
+ v = self.ancestor[v][k]
+ # We can be certain that level[u] = level[v]. Reason: binary representation of all natural numbers.
+ # Do we need to move to step 2?
+ if (u == v): return u
+ # STEP 2: find the highest ancestor where u != v. Then the parent is the LCA
+ for k in range(self.MAX_LOG-1, -1, -1):
+ if (self.ancestor[u][k] != self.ancestor[v][k]):
+ # Move up 2^k steps
+ u = self.ancestor[u][k]
+ v = self.ancestor[v][k]
+ return self.parent[u]
+
+ def dist(self, u, v):
+ # Find the distance between two vertices - very simple if we have LCA.
+ pass # TODO
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+
template<typename T = int> struct LCA {
+ // vertices are represented as numbers 0->n-1.
+ // number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
+ int MAX_LOG = 20;
+ int n; vector<vector<pair<int, T> > > adjacent;
+ vi parent, level;
+ vvi ancestor;
+ vector<T> length;
+
+ LCA(int n_vertices) : n(n_vertices), adjacent(n), parent(n), level(n), length(n) { }
+
+ void add_edge(int u, int v, T weight=1) {
+ adjacent[u].emplace_back(v, weight);
+ adjacent[v].emplace_back(u, weight);
+ }
+
+ void dfs(int source, int c_parent, int c_level, T c_length) {
+ // Search from the source down the tree and set parent, level, length accordingly.
+ parent[source] = c_parent;
+ level[source] = c_level;
+ length[source] = c_length;
+ for (auto v: adjacent[source])
+ if (v.first != c_parent)
+ dfs(v.first, source, c_level+1, c_length+v.second);
+ }
+
+ void build(int root=0) {
+ // Once edges are added, build the tree/data structure.
+ dfs(root, -1, 0, 0);
+ // Compute ancestor
+ ancestor.assign(n, vi(MAX_LOG, -1));
+ // Initial step: ancestor[v][0] = parent[v]
+ for (int v=0; v<n; v++)
+ ancestor[v][0] = parent[v];
+ // Now, compute ancestor[v][k] from 1->MAX_LOG
+ for (int k=1; k < MAX_LOG; k++)
+ for (int v=0; v<n; v++)
+ if (ancestor[v][k-1] != -1) {
+ // Move 2^{k-1} up, then 2^{k-1} again.
+ ancestor[v][k] = ancestor[ancestor[v][k-1]][k-1];
+ }
+ }
+
+ int query(int u, int v, int root=-1) {
+ // What is the lowest common ancestor of u, v?
+ // Extension: Make this query from any root vertex you want.
+ if (root != -1) {
+ // TODO
+ }
+ // assume that u is higher up than v, to simplify the code below
+ if (level[u] > level[v]) swap(u, v);
+ // STEP 1: set u and v to be ancestors with the same level
+ for (int k=MAX_LOG-1; k>=0; k--)
+ if (level[v] - (1 << k) >= level[u]) {
+ // If v is 2^k levels below u, move it up 2^k levels.
+ v = ancestor[v][k];
+ }
+ // We can be certain that level[u] = level[v]. Reason: binary representation of all natural numbers.
+ // Do we need to move to step 2?
+ if (u == v) return u;
+ // STEP 2: find the highest ancestor where u != v. Then the parent is the LCA
+ for (int k=MAX_LOG-1; k>=0; k--)
+ if (ancestor[u][k] != ancestor[v][k]) {
+ // Move up 2^k steps
+ u = ancestor[u][k];
+ v = ancestor[v][k];
+ }
+ return parent[u];
+ }
+
+ T dist(int u, int v) {
+ // Find the distance between two vertices - very simple if we have LCA.
+ // TODO
+ }
+};
+
Nice! That’s the main functionality of LCA completed.
Corollaries
Let’s quickly tackle the two remaining implementations:
- Calculating the distance between two vertices
u
and v
is the same as calculating the distance between u
and query(u, v)
, and adding that to the distance between v
and query(u, v)
- Calculating the LCA from a particular root, just requires a slight change in perspective. For two vertices
u
and v
, and custom root r
, the LCA will always be one of query(u, v)
, query(u, r)
or query(v, r)
.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+
class LCA:
+"""
+vertices are represented as numbers 0->n-1.
+"""
+
+ # number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
+ MAX_LOG = 20
+
+ def __init__(self, n_vertices):
+ self.n = n_vertices
+ self.adjacent = [[] for _ in range(self.n)]
+
+ def add_edge(self, u, v, weight=1):
+ self.adjacent[u].append((v, weight))
+ self.adjacent[v].append((u, weight))
+
+ def dfs(self, source, c_parent, c_level, c_length):
+ # Search from the source down the tree and set parent, level, length accordingly.
+ self.parent[source] = c_parent
+ self.level[source] = c_level
+ self.length[source] = c_length
+ for child, weight in self.adjacent[source]:
+ if child != c_parent:
+ self.dfs(child, source, c_level + 1, c_length + weight)
+
+ def build(self, root):
+ # Once edges are added, build the tree/data structure.
+ self.parent = [None]*self.n
+ self.level = [None]*self.n
+ self.length = [None]*self.n
+ self.dfs(root, -1, 0, 0)
+ self.ancestor = [[-1]*self.MAX_LOG for _ in range(self.n)]
+ # Initial step: ancestor[v][0] = parent[v]
+ for v in range(self.n):
+ self.ancestor[v][0] = self.parent[v]
+ # Now, compute ancestor[v][k] from 1->MAX_LOG
+ for k in range(1, self.MAX_LOG):
+ for v in range(self.n):
+ if self.ancestor[v][k-1] != -1:
+ # Move 2^{k-1} up, then 2^{k-1} again.
+ self.ancestor[v][k] = self.ancestor[self.ancestor[v][k-1]][k-1]
+
+ def query(self, u, v, root=None):
+ # What is the lowest common ancestor of u, v?
+ # Extension: Make this query from any root vertex you want.
+ if root is not None:
+ # Custom root -- see diagrams below for reasoning.
+ a = self.query(u, v)
+ b = self.query(u, root)
+ c = self.query(v, root)
+ # Case 1: root is in the same component as u when `a` is removed from the tree. So `b` is the LCA
+ if a == c and c != b: return b
+ # Case 2: root is in the same component as v when `a` is removed from the tree. So `c` is the LCA
+ if a == b and c != b: return c
+ # Case 3: b and c are above a in the tree. So return a
+ return a
+ # assume that u is higher up than v, to simplify the code below
+ if self.level[u] > self.level[v]:
+ u, v = v, u
+ # STEP 1: set u and v to be ancestors with the same level
+ for k in range(self.MAX_LOG-1, -1, -1):
+ if (self.level[v] - (1 << k) >= self.level[u]):
+ # If v is 2^k levels below u, move it up 2^k levels.
+ v = self.ancestor[v][k]
+ # We can be certain that level[u] = level[v]. Reason: binary representation of all natural numbers.
+ # Do we need to move to step 2?
+ if (u == v): return u
+ # STEP 2: find the highest ancestor where u != v. Then the parent is the LCA
+ for k in range(self.MAX_LOG-1, -1, -1):
+ if (self.ancestor[u][k] != self.ancestor[v][k]):
+ # Move up 2^k steps
+ u = self.ancestor[u][k]
+ v = self.ancestor[v][k]
+ return self.parent[u]
+
+ def dist(self, u, v):
+ # Find the distance between two vertices
+
+ return self.length[u] + self.length[v] - 2 * self.length[self.query(u, v)]
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+
template<typename T = int> struct LCA {
+ // vertices are represented as numbers 0->n-1.
+ // number such that 2^{MAX_LOG} > n. 20 works for n <= 10^6.
+ int MAX_LOG = 20;
+ int n; vector<vector<pair<int, T> > > adjacent;
+ vi parent, level;
+ vvi ancestor;
+ vector<T> length;
+
+ LCA(int n_vertices) : n(n_vertices), adjacent(n), parent(n), level(n), length(n) { }
+
+ void add_edge(int u, int v, T weight=1) {
+ adjacent[u].emplace_back(v, weight);
+ adjacent[v].emplace_back(u, weight);
+ }
+
+ void dfs(int source, int c_parent, int c_level, T c_length) {
+ // Search from the source down the tree and set parent, level, length accordingly.
+ parent[source] = c_parent;
+ level[source] = c_level;
+ length[source] = c_length;
+ for (auto v: adjacent[source])
+ if (v.first != c_parent)
+ dfs(v.first, source, c_level+1, c_length+v.second);
+ }
+
+ void build(int root=0) {
+ // Once edges are added, build the tree/data structure.
+ dfs(root, -1, 0, 0);
+ // Compute ancestor
+ ancestor.assign(n, vi(MAX_LOG, -1));
+ // Initial step: ancestor[v][0] = parent[v]
+ for (int v=0; v<n; v++)
+ ancestor[v][0] = parent[v];
+ // Now, compute ancestor[v][k] from 1->MAX_LOG
+ for (int k=1; k < MAX_LOG; k++)
+ for (int v=0; v<n; v++)
+ if (ancestor[v][k-1] != -1) {
+ // Move 2^{k-1} up, then 2^{k-1} again.
+ ancestor[v][k] = ancestor[ancestor[v][k-1]][k-1];
+ }
+ }
+
+ int query(int u, int v, int root=-1) {
+ // What is the lowest common ancestor of u, v?
+ // Extension: Make this query from any root vertex you want.
+ if (root != -1) {
+ // Custom root -- see diagrams below for reasoning.
+ int a = query(u, v);
+ int b = query(u, root);
+ int c = query(v, root);
+ // Case 1: root is in the same component as u when `a` is removed from the tree. So `b` is the LCA
+ if (a == c and c != b) return b;
+ // Case 2: root is in the same component as v when `a` is removed from the tree. So `c` is the LCA
+ if (a == b and c != b) return c;
+ // Case 3: b and c are above a in the tree. So return a
+ return a;
+ }
+ // assume that u is higher up than v, to simplify the code below
+ if (level[u] > level[v]) swap(u, v);
+ // STEP 1: set u and v to be ancestors with the same level
+ for (int k=MAX_LOG-1; k>=0; k--)
+ if (level[v] - (1 << k) >= level[u]) {
+ // If v is 2^k levels below u, move it up 2^k levels.
+ v = ancestor[v][k];
+ }
+ // We can be certain that level[u] = level[v]. Reason: binary representation of all natural numbers.
+ // Do we need to move to step 2?
+ if (u == v) return u;
+ // STEP 2: find the highest ancestor where u != v. Then the parent is the LCA
+ for (int k=MAX_LOG-1; k>=0; k--)
+ if (ancestor[u][k] != ancestor[v][k]) {
+ // Move up 2^k steps
+ u = ancestor[u][k];
+ v = ancestor[v][k];
+ }
+ return parent[u];
+ }
+
+ T dist(int u, int v) {
+ // Find the distance between two vertices
+ return length[u] + length[v] - 2 * length[query(u, v)];
+ }
+};
+
And that’s our implementation done! Now get out there and solve some problems!
Related Problems
This post is licensed under GNU GPL V3 by the author. Comments powered by Disqus.
diff --git a/posts/mod/index.html b/posts/mod/index.html
new file mode 100644
index 0000000..42d0c96
--- /dev/null
+++ b/posts/mod/index.html
@@ -0,0 +1,89 @@
+ Modular Arithmetic | Monash Code Binder Modular Arithmetic
What is it?
Modular Arithmetic encompasses all sorts of theorems and optimizations surrounding the %
operator in C and Python.
As you’ll see in the related problems, modulo arithmetic is often tied together in a question regarding counting things (combinatorics) or probabilities.
For those who haven’t come up across it before, %
is an operation normally performed on two integers \(a\) and \(b\), and \(a\ \%\ b\) can be intuitively thought of as the remainder when you divide \(a\) by \(b\).
For positive integers, this gives us the relation
\[a = a\ //\ b + a\ \%\ b\]Where \(//\) denotes floor division.
C++ vs Python
The %
operator actually does slightly different things in Python and C++. In particular, when the left hand side is a negative number (and the right hand side is positive), Python will always return a non-negative result, while C++ will return a result that is negative or zero!
1
+
print(-4 % 3)
+
1
+2
+3
+
int main() {
+ cout << (-4 % 3) << endl;
+}
+
The Python code will output 2, while the C++ code will output -1.
Characteristics
While %
at first might not seem to have much use, it does have some very interesting characteristics.
For most of this we’ll be assuming that both numbers are positive, just because most contest problems do and its easier to wrap your head around.
Addition, Subtraction, Multiplication
The %
operator can have its ordered swapped with any of the above operations. In particular:
\[\begin{eqnarray} (a + b)\ \%\ c &= ((a\ \%\ c) + (b\ \%\ c))\ \%\ c,\nonumber \\ (a - b)\ \%\ c &= ((a\ \%\ c) - (b\ \%\ c))\ \%\ c,\nonumber \\ (a * b)\ \%\ c &= ((a\ \%\ c) * (b\ \%\ c))\ \%\ c. \end{eqnarray}\]Multiples
\(a\ \%\ b = 0\) iff \(b\) is a factor of \(a\).
Inverse
For some particular \(a\) and \(b\), with \(b\) prime, let \(a^{-1}\) denote the modular inverse of \(a\). The modular inverse of \(a\) is the only number between \(1\) and \(b-1\) such that \(a*a^{-1}\ \%\ b = 1\).
As long as \(b\) is prime (or at least \(a, b\) are coprime), this inverse will always exist and there will always only be one of them.
We’ll come back to inverses later because they pop up in a few questions, often regarding probabilities.
Fermat’s Little Theorem
For prime \(p\), we know the following is true for any \(a\):
\[a^p\ \%\ p = a\ \%\ p.\]In particular, we also have
\[a^{p-2} \times a\ \%\ p = 1,\]so \(a^{p-2}\) is \(a^{-1}\).
Computing things
Exponentials
The most common application of modular arithmetic is simply because the expected output would normally be way too large to store in a long long
or something similar, and so the question asks to output the answer modulo (%
) 1000000007 or some other number (This one happens to be prime).
To compute \(a^b\ \%\ m\), you might think this requires us to do \(O(b)\) calculation, but in fact we can do it in \(O(\log(B))\).
1
+2
+3
+4
+5
+6
+7
+8
+9
+
def expmod(a, b, m):
+ res = 1 % m
+ a %= m
+ while b:
+ if (b & 1):
+ res = (res * a) % m
+ a = (a*a) % m
+ b //= 2
+ return res
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+
typedef __int128 big;
+
+big expmod(big a, big b, big m) {
+ big res=1%m;
+ a %= m;
+ for (; b; b /= 2) {
+ if (b&1) {
+ res=(res*a)%m;
+ }
+ a=(a*a)%m;
+ }
+ return res;
+}
+
This is just a modification to the normal integer exponent algorithm, utilising the fact that we can decompose
\[a^b = \prod_{i=0}^k a^{b_i2^i},\quad b = \sum_{i=0}^k b_i2^i.\]The binary decomposition of \(b\).
Modular Inverse
1
+2
+3
+4
+5
+6
+7
+
# works for any a, m coprime
+def inv(a, m):
+ _, x, y = gcd(m, a)
+ return y % m
+# works for m prime
+def inv(a, m):
+ return expmod(a, m-2, m)
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+
// works for any a, m coprime
+ll inv(ll a, ll m) {
+ ll x, y;
+ gcd(m, a, x, y);
+ // Ensure outcome is positive. CPP exclusive.
+ return ((y % m) + m) % m;
+}
+// works for m prime
+ll inv(ll a, ll m) {
+ return expmod(a, m-2, m);
+}
+
Here gcd is the function that takes two numbers \(a\) and \(b\), and returns:
- The greatest common factor of \(a\) and \(b\)
- Two numbers \(x\) and \(y\) such that \(ax + by = \text{gcd}(a, b)\).
This works because \(ya = \text{gcd}(m, a) - xm\). But, for \(m\) prime, we have \(\text{gcd}(m, a) = 1\):
\[ya\ \%\ m = (1 - xm)\ \%\ m = 1.\]And so \(y\ \%\ m\) has to be the multiplicative inverse of \(a\).
Fractions and Modular inverses
The final, and often most crucial trick goes as follows:
A common contest problem ends with the following statement:
The output can be expressed as an irreducible fraction \(\frac{p}{q}\). Output \(pq^{-1}\) modulo 1000000007.
While this can seem daunting, \(pq^{-1}\) has some nice properties. Let’s see them:
Properties
Addition
Consider two fractions \(\frac{a}{b}\) and \(\frac{c}{d}\). We have:
\[\frac{a}{b} + \frac{c}{d} = \frac{ad + bc}{bd},\]and (modulo \(m\)):
\[ab^{-1} + cd^{-1} = b^{-1}d^{-1}(ad + bc) = (ad + bc)(bd)^{-1}.\]Multiplication
Consider two fractions \(\frac{a}{b}\) and \(\frac{c}{d}\). We have:
\[\frac{a}{b} \times \frac{c}{d} = \frac{ac}{bd},\]and (modulo \(m\)):
\[ab^{-1} \times cd^{-1} = acb^{-1}d^{-1} = ac(bd)^{-1}.\]Factorisation
Consider for the last time a single reducible fraction \(\frac{ka}{kb}\). We have:
\[(ka)(kb)^{-1} = kak^{-1}b^{-1} = ab^{-1}.\]Wrap up
So, given the above, rather than having to store \(\frac{a}{b}\) for ludicrously sized \(a\) and \(b\), we can instead compute \(ab^{-1}\) and do arithmetic with these values.
As an example, if we wanted to compute \(\frac{p}{q} = (\frac{a}{b} + \frac{c}{d}) \times \frac{e}{f}\), then we could instead output
\[(ab^{-1} + cd^{-1}) \times ef^{-1}\ \%\ m.\]How cool is that!
Related Problems
This post is licensed under GNU GPL V3 by the author. Comments powered by Disqus.
diff --git a/posts/problems-21-s2-c1/index.html b/posts/problems-21-s2-c1/index.html
new file mode 100644
index 0000000..f2210a4
--- /dev/null
+++ b/posts/problems-21-s2-c1/index.html
@@ -0,0 +1,559 @@
+ Challenge Problems - 2021 Sem 2, Contest 1 | Monash Code Binder Challenge Problems - 2021 Sem 2, Contest 1
Sports Loans
Statement
Andrew is head of the sports club, and manages the inventory. Part of Andrew’s job is loaning footballs to people, and collecting those footballs once they have been used.
At the start of the day, Andrew has \(r\) footballs in stock, and knows that \(p+q\) people will approach him over the course of the day. \(p\) people will request a football, while \(q\) people will return a football. What Andrew does not know is the order in which these people will approach him.
Of course, Andrew wants to be able to give a football to everyone who requests one, when they request one. So for example if the first \(r+1\) people want a football, Andrew can’t give a football to the last person.
Andrew wants to know the probability that the above situation does not occur today, in other words the probability that every time someone requests a football, Andrew has one in stock.
Input / Output
Input will consist of three space separated integers, \(p, q\) and \(r\), as defined in the problem statement. Output should be a single number, the probability that Andrew will always be able to give a football to anyone who requests it. This number should have absolute error less than \(10^{-8}\).
Example Run
Input
1
+
4 1 3
+
Output
1
+
0.8
+
Since there is a 20% chance that the “return a football” event occurs before all 4 “request a football” events, which would cause problems.
Hints / Solution
Hint 1
Simulating won’t be enough, because of the sizes of p
, q
and r
. What you should instead do is try to find a general form for the probability, based on p
, q
, r
.
Start by trying to come up with a visualisation of this process in 2-dimensional space. What do good/bad orderings of people look like?
Hint 2
The visualisation we are looking for is one where we begin at point \((r, 0)\), and for each person, we either move to the right 1 unit (person returns a football), or up 1 unit (person requests a football).
A run is then invalid when we cross (not touch) the line \(y = x\). For an invalid run, what happens when we flip all points across the line \(y = x + 1\) before the intersection?
What is the total number of paths from this new starting point to \((q+r, p)\)?
Solution
To handle the easy cases, if \(p \leq r\), then the probability is 1 (We can’t possibly run out of footballs). If \( p > r + q\), then the probability is 0 (We can’t possibly handle everyones request). We assume neither of these is the case for the following discussion.
Let’s first view the problem through a different lens, to make the solution a bit more natural. Imagine Andrew is a point on the 2D plane. Whenever a person approaches him to give him a football, he moves one unit in the positive \(x\)-axis, and whenever a person approaches him to request a football, he moves one unit in the positive \(y\)-axis.
We can think of this as the \(x\)-axis representing footballs returned, while the \(y\) axis represents footballs taken. Since we start with \(r\) footballs, it might make sense to start Andrew at the position \((r, 0)\). Then, when Andrew is on the line \(y = x\), we know that Andrew has exactly 0 balls in inventory. Therefore we want to know the probability that Andrew never dips above this line (or equivalently, that Andrew never touches the line \(y = x + 1\)).
Since any ordering of the $p+q$ people is equally likely, we can simply count all possible distinct paths from \((r, 0)\) to \((r+q, p)\), and the proportion of these paths which sit below the line \(y = x+1\) is the probability we want. The number of all possible paths is \({p+q \choose p} = {p+q \choose q}\), since we can just pick the locations of the \(p\) (or \(q\)) people in our ordering.
Now, counting the number of paths that avoid the line is tough, but we can do something similar by finding a bijection between invalid paths and some other collection.
Rather than considering paths from \((r, 0)\) to \((r+q, p)\), what if we instead started from the same point, reflected on the line \(y = x+1\)? Then we’d be looking at paths from \((-1, r+1)\) to \((r+q, p)\). Note that every path between these two points needs to touch the line \(y = x+1\). Furthermore, we can turn each of these paths into an invalid path from \((r, 0)\) to \((r+q, p)\) in the following way:
- Find the first intersection of the path with the line \(y = x + 1\) (Some intersection must exist).
- Mirror the path along the line \(y = x + 1\) before this intersection.
Since \((-1, r+1)\) is the mirrored version of \((r, 0)\), all of these new paths are distinct paths from \((r, 0)\) to \((r+q, p)\). Furthermore, since the original paths hit the line, each of the these mirrored paths also hit the line and are therefore invalid. Hopefully it’s also easy to see that every possible invalid path can be reached via this mirror method.
Therefore, the total number of invalid paths is equal to the total number of any of path between \((-1, r+1)\) and \((r+q, p)\). By the same argumentation as before, there are \({(r+q+1) + (p-r-1) \choose r+q+1} = {p+q \choose p - r - 1}\) possible paths.
Now that we know how many paths there are in total, and how many paths are invalid, we can calculate some probabilities. The probability that Andrew does run into this situation (That we have a bad path) is:
\[P(\text{bad}) = \frac{p+q \choose p - r - 1}{p+q \choose p} = \frac{(p+q)!p!q!}{(p+q)!(p-r-1)!(q+r+1)!} = \frac{\prod^r_{i=0}p-i}{\prod^r_{i=0}q+i+1}.\]The probability the question asks for is then just \(P(\text{good}) = 1 - P(\text{bad})\). Note that the cancellation above is required to fit within precision and time limits, as we can’t compute \(p!\) for large enough \(p\) within time.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+
# Read 3 ints
+p, q, r = list(map(int, input().split()))
+if p <= r:
+ print(1)
+elif p > r + q:
+ print(0)
+else:
+ p_bad = 1
+ for i in range(r+1):
+ p_bad *= (p-i) / (q+i+1)
+ # Be safe with 10 precision points.
+ print(f"{1-p_bad:.10f}")
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+
#include <iostream>
+#include <iomanip>
+
+using namespace std;
+
+int main() {
+
+ // p, q can be large, use long long.
+ long long p, q, r;
+
+ cin >> p >> q >> r;
+
+ if (p <= r) {
+ cout << 1 << endl;
+ } else if (p > r + q) {
+ cout << 0 << endl;
+ } else {
+ double p_bad = 1;
+
+ for (int i=0; i<= r; i++) {
+ p_bad = p_bad * ((double)(p-i)) / ((double)(q+i+1));
+ }
+
+ cout << setprecision(10) << fixed << 1 - p_bad << endl;
+ }
+
+ return 0;
+}
+
Optimal Farming
Statement
Amy has just bought a farm in the outback, and wants to start selling tomatoes. Some of the crops in the farm are already tomatoes, but there are others she wants to get rid of and replace with tomatoes.
Amy has employed the help of Square Tomatoes Group™. Amy can pay the group $\(s\) to plant an \(s \times s\) grid of crops with tomatoes (It doesn’t matter if the existing crop was tomatoes or not, all grid squares become tomatoes. This square can also exceed Amy’s farm). Amy wants to minimise her cost to Square Tomatoes Group™ such that all crops are tomatoes.
Input / Output
Input starts with two integers \(1 \leq l, w \leq 30\), the length and width of the farm, separated by space. Input then contains \(l\) lines, each containing a string of \(w\) characters. Each of these characters represent a grid square in the farm. This square is a tomato crop if and only if the character printed is a T
.
Output should be a single integer, the minimum Amy has to pay to fill her farm with tomato crops.
Example Test
Input
1
+2
+3
+4
+
3 4
+PWTT
+TCTT
+TTTL
+
Output
1
+
3
+
Explanation: We can pay Square Tomatoes Group™ \($2\) to plant tomatoes in the top-left 2x2 area, and then \($1\) to plant tomatoes in the final bottom-right square.
Hints / Solution
Hint 1
Note that if every row and column has a square with no tomato, then the answer is rather obvious. The problem only gets interesting when an entire row/column is already tomatoes.
Hint 2
Suppose we don’t go for the easy solution of just using a massive square to cover our farm, and have a cheaper solution. Then one of the rows or columns in the farm must be untouched. How can we recurse from here?
Solution
We will generalise and compute \(\text{cost}(x1, x2, y1, y2)\), the cost of converting the rectangle \([x1, x2), [y1, y2)\) all to tomatoes. The question is asking us to compute \(\text{cost}(0, w, 0, l)\).
Note that we can always spend \($\text{max}(x2-x1, y2-y1)\) and cover the rectangle, by using a square that exceeds the bounds. Also, note that if a collection of squares overlaps every column of the rectangle, then the cost of planting all of these squares must be at least \(x2-x1\), and similarly if every row of the rectangle is overlapped by a square, the minimum cost is \(y2-y1\).
With this in mind, suppose there was a cheaper selection of squares that converts this entire rectangle to tomatoes. Then from the logic above, there must be some column or row which is not touched by these planted squares (A column or row that is already all tomatoes). This column/row separates our rectangle in two, and so we can solve the subproblem of \(\text{cost}\) on each of these rectangles.
For example, in the input given, we have a \(3 \times 4\) rectangle. The easy solution is to cover the entire thing with a \(4 \times 4\) square, costing \($4\).
But, column 2 is all tomatoes, so we can solve the subproblem on the left and right hand sides of this column, and see if doing this is cheaper. Continuing along, we find the left subproblem costs \($2\), and the right subproblem costs \($1\), and so the final result is \($3\).
We can achieve this within the time limit with Dynamic Programming.
The recursive definition is:
\[\text{cost}(x1, x2, y1, y2) := \text{min} \begin{cases} (x2 - x1) \times (y2 - y1) &\\ \text{cost}(x1, c, y1, y2) + \text{cost}(c, x2, y1, y2) & x1 < c < x2, \text{column } c \text{ from } y1 \to y2 \text{ all tomato.}\\ \text{cost}(x1, x2, y1, r) + \text{cost}(x1, x2, r, y2) & y1 < r < y2, \text{row } r \text{ from } x1 \to x2 \text{ all tomato.} \end{cases}\]The base case being that \(\text{cost}(x1, x1+1, y1, y1+1) = b\), where \(b\) is 0 if it’s a tomato plant, and 1 otherwise. In order to know when an entire segment of a row/column is all tomatoes, we can also use DP, by breaking up each square into individual rows / columns.
For both of these we have \(l^2w^2\) values to compute, and each of the values takes \(l + w\) operations to compute (In the recursive definition, we might recurse for every row and column in the square). So the total cost is about \(30^5 \approx 2 \times 10^7\)
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+
import sys
+# Maximum recursion for this problem is actually like 60, but better safe than sorry.
+sys.setrecursionlimit(10000)
+
+l, w = list(map(int, input().split()))
+
+grid = [input() for _ in range(l)]
+
+# is the rectangle from x1 to x2, y1 to y2 all tomato? (RHS exclusive)
+tomato_dp = [[[[None for _1 in range(31)] for _2 in range(31)] for _3 in range(31)] for _4 in range(31)]
+def all_tomato(x1, x2, y1, y2):
+ if tomato_dp[x1][x2][y1][y2] is not None:
+ return tomato_dp[x1][x2][y1][y2]
+ if x1 < x2 - 1:
+ tomato_dp[x1][x2][y1][y2] = all_tomato(x1, x2-1, y1, y2) and all_tomato(x2-1, x2, y1, y2)
+ elif y1 < y2 - 1:
+ tomato_dp[x1][x2][y1][y2] = all_tomato(x1, x2, y1, y2-1) and all_tomato(x1, x2, y2-1, y2)
+ else:
+ # y2 = y1+1, x2 = x1+1.
+ tomato_dp[x1][x2][y1][y2] = grid[x1][y1] == "T"
+ return tomato_dp[x1][x2][y1][y2]
+
+cost_dp = [[[[None for _1 in range(31)] for _2 in range(31)] for _3 in range(31)] for _4 in range(31)]
+def cost(x1, x2, y1, y2):
+ if cost_dp[x1][x2][y1][y2] is not None:
+ return cost_dp[x1][x2][y1][y2]
+ if x1 == x2 or y1 == y2:
+ # Empty rectangle.
+ return 0
+ cur_min = max(x2-x1, y2-y1)
+ # Otherwise, there is an empty row/column we can exclude. Simply solve this subproblem.
+ for c in range(x1, x2):
+ if all_tomato(c, c+1, y1, y2):
+ cur_min = min(cost(x1, c, y1, y2) + cost(c+1, x2, y1, y2), cur_min)
+ for r in range(y1, y2):
+ if all_tomato(x1, x2, r, r+1):
+ cur_min = min(cost(x1, x2, y1, r) + cost(x1, x2, r+1, y2), cur_min)
+ cost_dp[x1][x2][y1][y2] = cur_min
+ return cost_dp[x1][x2][y1][y2]
+
+print(cost(0, l, 0, w))
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+
#include <iostream>
#include <string>

#define FOR(i,j,k) for(int i=j; i<k; i++)
// Fully parenthesised so the macros expand safely inside larger expressions.
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#define MAXN 30

using namespace std;

int l, w;

const int UNKNOWN = -1;

// Memo tables over half-open rectangles [x1, x2) x [y1, y2).
int DP_TOMATO[MAXN+1][MAXN+1][MAXN+1][MAXN+1];
int DP_COST[MAXN+1][MAXN+1][MAXN+1][MAXN+1];
string grid[MAXN+1];

bool tomato(int x1, int x2, int y1, int y2) {
    // Is the rectangle [x1, x2), [y1, y2) already tomato?
    if (DP_TOMATO[x1][x2][y1][y2] != UNKNOWN)
        return DP_TOMATO[x1][x2][y1][y2];
    if (x1 < x2 - 1) {
        // Look at the column x=x2-1 separately
        DP_TOMATO[x1][x2][y1][y2] = tomato(x1, x2-1, y1, y2) && tomato(x2-1, x2, y1, y2);
    } else if (y1 < y2 - 1) {
        // Look at the row y=y2-1 separately
        DP_TOMATO[x1][x2][y1][y2] = tomato(x1, x2, y1, y2-1) && tomato(x1, x2, y2-1, y2);
    } else {
        // We are a 1x1.
        DP_TOMATO[x1][x2][y1][y2] = grid[x1][y1] == 'T';
    }
    return DP_TOMATO[x1][x2][y1][y2];
}

int cost(int x1, int x2, int y1, int y2) {
    // Minimum cost to cover the rectangle [x1, x2) x [y1, y2).
    if (DP_COST[x1][x2][y1][y2] != UNKNOWN)
        return DP_COST[x1][x2][y1][y2];
    if (x1 == x2 || y1 == y2)
        // Empty rectangle. Possible in the below recursion so just return 0.
        return 0;
    // We can always cover the rectangle by using a big square.
    int cur_min = MAX(x2-x1, y2-y1);
    FOR(c,x1,x2)
        if (tomato(c, c+1, y1, y2))
            // If this column is tomato, then we can try solving the two subproblems instead by removing the column.
            cur_min = MIN(
                cur_min,
                cost(x1, c, y1, y2) + cost(c+1, x2, y1, y2)
            );
    FOR(r,y1,y2)
        if (tomato(x1, x2, r, r+1))
            // If this row is tomato, then we can try solving the two subproblems instead by removing the row.
            cur_min = MIN(
                cur_min,
                cost(x1, x2, y1, r) + cost(x1, x2, r+1, y2)
            );
    DP_COST[x1][x2][y1][y2] = cur_min;
    return DP_COST[x1][x2][y1][y2];
}

int main() {

    cin >> l >> w;
    FOR(i,0,l) {
        cin >> grid[i];
    }

    // Mark every memo entry as unknown before solving.
    FOR(x1,0,l+1)FOR(x2,0,l+1)FOR(y1,0,w+1)FOR(y2,0,w+1) {
        DP_TOMATO[x1][x2][y1][y2] = UNKNOWN;
        DP_COST[x1][x2][y1][y2] = UNKNOWN;
    }

    cout << cost(0, l, 0, w) << endl;

    return 0;
}
+
Repetitive Jugglers
Statement
Alice is the leader of a juggling crew, and they are set to perform a crazy juggling trick.
In this trick, every member of the crew starts off with a different coloured ball. Every member then picks another member of the crew (possibly themselves), let us call that member their receiver.
Then, the trick begins. Every second, every crew member will throw all of the balls they are holding to their designated receiver.
The trick only stops once everyone has the same ball they started with (Note that not always does this trick stop!)
Alice wants to know, given who has chosen who as receiver, whether the game will end, and if so, how many seconds this will take.
Input
Input will consist of two lines.
The first line will contain an integer \(n\), the number of members in the juggling crew.
The second line will then contain \(n\) space-separated integers. The \(i\)th integer represents the \(i\)th crew member’s pick for receiver.
(So we enumerate crew members \(1, 2, 3\ldots\), and if the second integer is \(1\), that means that crew member \(2\) has chosen \(1\) as their receiver.)
It is guaranteed that if the trick does stop, it will stop before \(10^{15}\) seconds have passed.
Output
If the trick will never finish, print \(-1\). Otherwise, print the total length of the trick, in seconds.
Example Test
Input
1
+2
+
3
+2 1 3
+
Output
1
+
2
+
After 1 second, person 1 and person 2 throw the balls at each other, and person 3 throws the ball to themselves. As such person 1 is holding person 2’s ball, and person 2 is holding person 1’s ball. Person 3 is holding their own ball.
After 2 seconds, the same action occurs, and so everyone is holding their own ball.
Hints / Solution
Hint 1
Since this trick might continue for \(10^{15}\) seconds, we cannot simulate it (Especially with large \(n\)).
We need to figure out ahead of time when this will occur.
Notice that if one person receives 2 or more balls at any point in time, the trick will never end.
Hint 2
Only selections of “receivers” in which every member is the receiver of exactly one member will finish, and they will always finish.
For math inclined individuals, these receivers represent a permutation of the group, and we want to know how many repeated applications of this permutation are needed to take us back to the identity.
We can figure out how long this will take based on cycles that are present in the permutation.
Solution
Note that if any member receives two juggling balls, then our sequence can never return to how it was. Therefore everyone must receive a single ball every second. In other words, our \(n\) space-separated integers must be a permutation of the numbers \(1\) through to \(n\). Note that in a permutation, there are multiple distinct cycles of different sizes (For example 1 passes to 3 passes to 7 passes to 4 passes to 1). Notably, everyone in these cycles has their ball every \(k\) seconds, where \(k\) is the length of the cycle.
Therefore, if we have cycles of length \(k_1, k_2, \ldots k_a\), then the first time the entire sequence will repeat must be the least common multiple of these values \(k_1, k_2, \ldots k_a\) (The first number \(c\), which is divisible by all of \(k_1, k_2, \ldots, k_a\)).
So our solution just needs to find each of these cycles, and count their length. Then compute the least common multiple.
In the sample input, we have a permutation with 2 cycles (1 passes to 2 passes to 1, and 3 passes to 3). These cycles are of length 2 and 1 respectively.
Therefore every 2 seconds, members 1 and 2 will have their balls, and every second, member 3 will have their ball. Because of this, the answer is the smallest number which is divisible by 2 and 1 (2).
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+
n = int(input())
# make it 0 -> n-1.
choices = list(map(lambda x: int(x)-1, input().split()))

# Greatest common divisor (Euclid's algorithm).
def gcd(a, b):
    if b == 0:
        return abs(a)
    return gcd(b, a%b)

# Least common multiple.
def lcm(a, b):
    return a * b // gcd(a, b)

# set() removes duplicates. If no duplicates, then we have a permutation.
if len(set(choices)) == n:
    # Permutation
    # Find the cycle lengths
    lengths = []
    found = [False]*n
    for x in range(n):
        # Is x not already in a cycle?
        if not found[x]:
            length = 0
            # Search through the cycle.
            cur = x
            while not found[cur]:
                found[cur] = True
                length += 1
                # Move to the next person
                cur = choices[cur]
            lengths.append(length)
    # Print the lcm of all lengths.
    cur_lcm = 1
    for length in lengths:
        # The lcm of a list is simply the pairwise lcm of each element, combined.
        cur_lcm = lcm(cur_lcm, length)
    print(cur_lcm)
else:
    # Not a permutation
    print(-1)
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+
#include <vector>
#include <iostream>
#include <map>

using namespace std;

typedef long long ll;
typedef vector<ll> vll;

// Greatest common divisor (Euclid). Inputs here are always non-negative.
ll gcd(ll a, ll b) {
    if (b == 0) return a;
    return gcd(b, a%b);
}

// Least common multiple. Divide before multiplying to delay overflow.
ll lcm(ll a, ll b) {
    return (a / gcd(a, b)) * b;
}

int main() {

    int n;
    cin >> n;

    vll nums(n);
    for (int i=0; i<n; i++) {
        cin >> nums[i];
        // Make it 0 -> n-1.
        nums[i]--;
    }

    // Check for duplicates. map values default-construct to 0 on first access.
    bool bad = false;
    map<int, int> dup_check;
    for (int i=0; i<n; i++) {
        dup_check[nums[i]]++;
        if (dup_check[nums[i]] > 1) {
            // Someone receives 2.
            bad = true;
        }
    }
    if (bad) {
        cout << -1 << endl;
    } else {
        // nums is a permutation: find each cycle and record its length.
        vll lengths;
        vector<bool> found(n, false);
        for (int i=0; i<n; i++) {
            if (!found[i]) {
                int length = 0;
                int cur = i;
                while (!found[cur]) {
                    found[cur] = true;
                    length++;
                    cur = nums[cur];
                }
                lengths.push_back(length);
            }
        }
        // The answer is the lcm of all cycle lengths.
        ll cur_lcm = 1;
        for (ll l: lengths) {
            cur_lcm = lcm(cur_lcm, l);
        }
        cout << cur_lcm << endl;
    }

    return 0;
}
+
This post is licensed under GNU GPL V3 by the author. Comments powered by Disqus.
diff --git a/posts/uf/index.html b/posts/uf/index.html
new file mode 100644
index 0000000..ae7df79
--- /dev/null
+++ b/posts/uf/index.html
@@ -0,0 +1,871 @@
+ Union Find / DSU | Monash Code Binder Union Find / DSU
Where is this useful?
In many problems, translating into a graph structure can prove helpful, as we can describe our problem in very abstract terms.
Once you’ve translated into this graph structure, often you might want to know whether two vertices are connected via a path, and if this is not the case, what two separate components they come from. Union Find allows us to not only answer this question, but slowly add edges to the graph and still answer these queries fast.
As such, Union Find is useful in any problem where connections are incrementally being added to some structure, and along the way you need to query what vertices are connected.
Implementing the Data Structure
Basics
Let’s first define the interface for our Union Find. We want to provide the ability to merge two vertices, and we should be able to query two vertices, asking if they are connected.
At first, every vertex is disconnected. We can add edges later as need be.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+
class UnionFind:
    """
    Interface skeleton for the Union Find / DSU data structure.

    vertices are represented as numbers 0->n-1.
    """

    def __init__(self, n):
        # n: number of vertices; every vertex starts in its own component.
        self.n = n

    def merge(self, a, b) -> bool:
        # Merge the two vertices a and b. Return a boolean which is true if they weren't already merged.
        pass # TODO

    def connected(self, a, b) -> bool:
        # Whether the two vertices a and b are connected.
        pass # TODO
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+
// Interface skeleton for the Union Find / DSU data structure.
struct UnionFind {
    // vertices are represented as numbers 0->n-1.
    int n;

    UnionFind(int n_verts) : n(n_verts) { }

    bool merge(int a, int b) {
        // Merge the two vertices a and b. Return a boolean which is true if they weren't already merged.
        // TODO
    }

    bool connected(int a, int b) {
        // Whether the two vertices a and b are connected.
        // TODO
    }
};
+
Now, we can take our first approach at the data structure. Notice that before any merging occurs, each component is uniquely identified by a single vertex contained within. As we merge our vertices, we’ll try keep it that way.
In order to do this, we can model each component as a rooted tree. The root of this tree is the identifier, and so from any vertex in the tree, we can get to the identifier by moving up the tree.
To merge two components (trees), we simply place the second tree as a child of the first tree. The second root no longer identifies a component, and the first root is now the identifier of the combined component.
So, to implement this, we’ll create a parent array, which contains the parent of each vertex. For vertices that are the root, they will be their own parents.
We will also need a method to find the identifier of any component, by moving up the tree. We will do this with find
in the code.
And we can already get around to implementing connected
and merge
. For connected
, a
and b
are in the same component if the identifier of their components are the same. For merge, we simply need to modify the parent
attribute of one identifier, so that it points to the root of the other component:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+
class UnionFind:
    """
    vertices are represented as numbers 0->n-1.

    Each component is a rooted tree; the root is the component's identifier.
    """

    def __init__(self, n):
        self.n = n
        # parent[x] = x to begin with.
        self.parent = list(range(n))

    def find(self, a):
        # Find the root of this component.
        if self.parent[a] == a:
            return a
        return self.find(self.parent[a])

    def merge(self, a, b) -> bool:
        # Merge the two vertices a and b. Return a boolean which is true if they weren't already merged.
        a = self.find(a)
        b = self.find(b)
        if a == b:
            return False
        # Hang b's tree under a's root.
        self.parent[b] = a
        return True

    def connected(self, a, b) -> bool:
        # Whether the two vertices a and b are connected (same root).
        a = self.find(a)
        b = self.find(b)
        return a == b
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+
// Union Find: each component is a rooted tree; the root identifies the component.
struct UnionFind {
    // vertices are represented as numbers 0->n-1.
    int n;
    std::vector<int> parent;

    UnionFind(int n_verts) : n(n_verts), parent(n_verts) {
        // parent[x] = x to begin with.
        std::iota(parent.begin(), parent.end(), 0);
    }

    int find(int a) {
        // Find the root of this component.
        if (parent[a] == a) return a;
        return find(parent[a]);
    }

    bool merge(int a, int b) {
        // Merge the two vertices a and b. Return a boolean which is true if they weren't already merged.
        a = find(a);
        b = find(b);
        if (a == b) return false;
        // Hang b's tree under a's root.
        parent[b] = a;
        return true;
    }

    bool connected(int a, int b) {
        // Whether the two vertices a and b are connected (same root).
        a = find(a);
        b = find(b);
        return a == b;
    }
};
+
Useful data
A keen eye might’ve spotted that there’s possibility of some bad complexity coming out of these methods. If components are merged badly (So that we have a very unbalanced tree) we can make it so that find
(and therefore merge/connected
) are O(n)
complexity. To improve this, and to make the data structure more useful as a whole, let’s take a quick detour and try to include some other data as part of our data structure:
size
: This should be an array which stores the size of each component. The size
entry for non-identifier vertices doesn’t matter and can be left with old data.rank
: This should be an array which stores the maximum depth of any component tree. The rank
entry for non-identifier vertices doesn’t matter and can be left with old data.
It could be a good bit of practice to try this yourself; Modify the methods above to store and update the size
and rank
values.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+
class UnionFind:
    """
    vertices are represented as numbers 0->n-1.

    Tracks component sizes and tree ranks alongside the parent forest.
    """

    def __init__(self, n):
        # Number of components
        self.n = n
        # parent[x] = x to begin with.
        self.parent = list(range(n))
        # size = number of vertices in component
        self.size = [1]*n
        # rank = max-depth of component tree
        self.rank = [1]*n

    def find(self, a):
        # Find the root of this component.
        if self.parent[a] == a:
            return a
        return self.find(self.parent[a])

    def merge(self, a, b) -> bool:
        # Merge the two vertices a and b. Return a boolean which is true if they weren't already merged.
        a = self.find(a)
        b = self.find(b)
        if a == b:
            return False
        self.size[a] += self.size[b]
        self.parent[b] = a
        # Depth only grows when the two trees were equally deep.
        self.rank[a] = max(self.rank[a], self.rank[b])
        if self.rank[a] == self.rank[b]:
            self.rank[a] += 1
        self.n -= 1
        return True

    def connected(self, a, b) -> bool:
        # Whether the two vertices a and b are connected.
        a = self.find(a)
        b = self.find(b)
        return a == b

    def size_component(self, a):
        # Find the size of a particular component.
        # Question: Why do we need to call `find`?
        return self.size[self.find(a)]

    def num_components(self):
        return self.n
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+
// Union Find tracking component sizes and tree ranks alongside the parent forest.
struct UnionFind {
    // vertices are represented as numbers 0->n-1.
    int n;  // current number of components
    std::vector<int> parent, size, rank;

    UnionFind(int n_verts) : n(n_verts), parent(n_verts), size(n_verts, 1), rank(n_verts, 1) {
        // parent[x] = x to begin with.
        std::iota(parent.begin(), parent.end(), 0);
    }

    int find(int a) {
        // Find the root of this component.
        if (parent[a] == a) return a;
        return find(parent[a]);
    }

    bool merge(int a, int b) {
        // Merge the two vertices a and b. Return a boolean which is true if they weren't already merged.
        a = find(a);
        b = find(b);
        if (a == b) return false;
        size[a] += size[b];
        parent[b] = a;
        // Depth only grows when the two trees were equally deep.
        rank[a] = std::max(rank[a], rank[b]);
        if (rank[a] == rank[b]) rank[a]++;
        n--;
        return true;
    }

    bool connected(int a, int b) {
        // Whether the two vertices a and b are connected.
        a = find(a);
        b = find(b);
        return a == b;
    }

    int size_component(int a) {
        // Find the size of a particular component.
        // Question: Why do we need to call `find`?
        return size[find(a)];
    }

    int num_components() { return n; }
};
+
If the maximum of rank[a]
and rank[b]
is equal to rank[b]
, then the total depth in the tree will be at most rank[b]+1
, since we must include the path from a
to b
, before considering any children of b
.
Armed with this information, we can make some better decisions when it comes to merging, and also start compressing the trees.
Depth reduction
Since we get bad complexity when merging trees with large rank as children, let’s always pick the largest rank tree to be the identifier. Then the overall rank of the resultant tree only increases if the rank of the two original trees was the same.
Additionally, every time we call find, we are traversing up our tree. But in this traversal, it is very cheap to simply connect every vertex along the way to the root vertex, using recursion.
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+
class UnionFind:
    """
    vertices are represented as numbers 0->n-1.

    Full version: union by rank plus path compression in `find`.
    """

    def __init__(self, n):
        # Number of components
        self.n = n
        # parent[x] = x to begin with.
        self.parent = list(range(n))
        # size = number of vertices in component
        self.size = [1]*n
        # rank = max-depth of component tree
        self.rank = [1]*n

    def find(self, a):
        # Find the root of this component.
        if self.parent[a] == a:
            return a
        # Whenever I call find, set the parent to be right above me (path compression).
        b = self.find(self.parent[a])
        self.parent[a] = b
        return b

    def merge(self, a, b) -> bool:
        # Merge the two vertices a and b. Return a boolean which is true if they weren't already merged.
        a = self.find(a)
        b = self.find(b)
        if a == b:
            return False
        # Union by rank: always hang the shallower tree under the deeper one.
        if (self.rank[a] < self.rank[b]):
            a, b = b, a
        self.size[a] += self.size[b]
        self.parent[b] = a
        if self.rank[a] == self.rank[b]:
            self.rank[a] += 1
        self.n -= 1
        return True

    def connected(self, a, b) -> bool:
        # Whether the two vertices a and b are connected.
        a = self.find(a)
        b = self.find(b)
        return a == b

    def size_component(self, a):
        # Find the size of a particular component.
        # Question: Why do we need to call `find`?
        return self.size[self.find(a)]

    def num_components(self):
        return self.n
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+
// Full Union Find: union by rank plus path compression in find().
struct UnionFind {
    // vertices are represented as numbers 0->n-1.
    int n;  // current number of components
    std::vector<int> parent, size, rank;

    UnionFind(int n_verts) : n(n_verts), parent(n_verts), size(n_verts, 1), rank(n_verts, 1) {
        // parent[x] = x to begin with.
        std::iota(parent.begin(), parent.end(), 0);
    }

    int find(int a) {
        // Find the root of this component.
        if (parent[a] == a) return a;
        // Whenever I call find, set the parent to be right above me (path compression).
        return parent[a] = find(parent[a]);
    }

    bool merge(int a, int b) {
        // Merge the two vertices a and b. Return a boolean which is true if they weren't already merged.
        a = find(a);
        b = find(b);
        if (a == b) return false;
        // Union by rank: always hang the shallower tree under the deeper one.
        if (rank[a] < rank[b]) std::swap(a, b);
        size[a] += size[b];
        parent[b] = a;
        if (rank[a] == rank[b]) rank[a]++;
        n--;
        return true;
    }

    bool connected(int a, int b) {
        // Whether the two vertices a and b are connected.
        a = find(a);
        b = find(b);
        return a == b;
    }

    int size_component(int a) {
        // Find the size of a particular component.
        // Question: Why do we need to call `find`?
        return size[find(a)];
    }

    int num_components() { return n; }
};
+
Complexity Analysis
And that is all the changes required to reduce the complexity of union find, but how much has it done?
Well, to construct a rank 2 tree we need to merge 2 rank 1 trees, to construct a rank 3 tree we need to merge 2 rank 2 trees, and so on and so forth. Therefore in a union find with n vertices, we have at most log2(n) rank on each tree in our data structure.
This means that find
is log(n), meaning both merge
and connected
are also log(n). (In fact, with the path compression above, the complexity is even less (inverse ackermann), but this isn’t super important under contest conditions)
And that’s the data structure fully taken care of. Now let’s solve some problems!
A simple application
Let’s try our hand at Friend Circle. Give it a shot yourself before reading the discussion below!
(Note: The time bounds for this problem are very small. Python will probably TLE. But give it a shot anyways!).
Hint
While the problem description is a bit sparse, hopefully you can spot that we care about what group of friends are connected by some friendship (If A and B are friends, and B and C are friends, then all 3 form a circle of friends, no need for A and C to be friends.)
So, if we let every person be a vertex in our graph, with edges representing friendship, then Union Find is exactly the tool we need. Before we get into coding we need only ask ourselves two things:
- What is the maximum size of our Union Find
n
? - How will I turn people’s names into the digits
0
to n-1
?
Solution
To answer 1, the maximum number of people is simply 2 times the total number of connections. For 2, we can use a dictionary/map to map strings to integers. To ensure every person is unique from 0
to n-1
, we can start a counter at 0
, and every time we see a new name, increment this counter. Then the old value of the counter is the id for that person:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+
from collections import defaultdict
t = int(input())

for _ in range(t):
    connections = int(input())
    # Each connection mentions at most 2 previously-unseen people.
    max_people = 2 * connections
    uf = UnionFind(max_people)
    cur_counter = 0
    def count_increase():
        # Hand out the next unused id each time a new name is seen.
        global cur_counter
        cur_counter += 1
        return cur_counter - 1
    # The defaultdict now assigns a new id to every new person mentioned.
    person_map = defaultdict(count_increase)
    for _ in range(connections):
        p1, p2 = input().split()
        uf.merge(person_map[p1], person_map[p2])
        # After each friendship, report the size of p1's friend circle.
        print(uf.size_component(person_map[p1]))
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+
int main() {
+
+ int tests;
+ int connections;
+
+ cin >> tests;
+
+ for (int t_no=0; t_no<tests; t_no++) {
+ cin >> connections;
+ int max_people = connections * 2;b
+ UnionFind uf(max_people);b
+ int counter = 0;
+ map<str, int> person_map;
+ for (int c=0; c<connections; c++>) {
+ string c1, c2;
+ cin >> c1 >> c2;
+ if (person_map.count(c1) == 0) person_map[c1] = counter++;
+ if (person_map.count(c2) == 0) person_map[c2] = counter++;
+ uf.merge(person_map[c1], person_map[c2]);b
+ cout << uf.size_component(person_map[c1]) << endl;b
+ }
+ }
+
+ return 0;
+}
+
A slightly hidden application
Next, lets try a harder problem - Roads of NITT.
Have a go!
(Note: The input format is very weird (There’s some whitespace where there shouldn’t be). My current python solution fails for this reason)
Hint
This problem seems similar but different to the problem above. We are still asking about connectivity, but we are breaking connections rather than forming them :(.
Consider this though - Would you be able to solve the problem if it was told in reverse?
Solution
Looking at the problem in reverse, it seems we start off with a disconnected area, and then bit by bit, more connections are made. This is starting to look like Union Find!
So all we need to do is:
- Calculate what roads remain at the end of the problem
- Answer the queries in reverse, joining instead of destroying
- Reverse these results and print them
However, we need to be a bit careful about what our results are in the first place - We want to count how many pairs of hostels are disconnected - This is an N^2 operation. We can do this in N using union find (For every vertex, we know how many vertices it is connected to (and therefore not connected to)), but we still don’t want to do this for every query. Let’s start by calculating the correct value after all roads have been destroyed.
If a road is formed, how many old pairs of hostels are no longer disconnected? A hostel can only be connected now and disconnected before if one of the hostels was already connected to the LHS of the road, and the other hostel was already connected to the RHS of the road. The number of possible pairs here is the size of the component on the LHS of the road, times the size of the component on the RHS of the road.
So every time we see an R
query, we just need to update our current count of disconnect pairs using the Union Find:
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+
t = int(input())

for _ in range(t):
    n = int(input())
    edges = []
    for _ in range(n-1):
        x, y = list(map(int, input().split()))
        # 0-index
        edges.append((x-1, y-1))
    # Every road starts intact; the "R" queries below mark roads as destroyed.
    connected = [True] * len(edges)
    q = int(input())
    queries = []
    for _ in range(q):
        queries.append(input())
        if queries[-1].startswith("R"):
            connected[int(queries[-1].split()[1])-1] = False
    uf = UnionFind(n)
    # Add all remaining roads
    for a in range(len(edges)):
        if connected[a]:
            uf.merge(edges[a][0], edges[a][1])
    # First - calculate how many pairs of hostels are disconnected.
    current = 0
    for x in range(n):
        current += n - uf.size_component(x)
    # Each disconnected pair was counted once from each end.
    current //= 2
    # Answering time!
    queries.reverse()
    answers = []
    for q in queries:
        if q.startswith("Q"):
            answers.append(current)
        else:
            edge_index = int(q.split()[1])-1
            # Re-adding a road connects size(LHS) * size(RHS) previously-disconnected pairs.
            current -= uf.size_component(edges[edge_index][0]) * uf.size_component(edges[edge_index][1])
            uf.merge(*edges[edge_index])
    answers.reverse()
    for a in answers:
        print(a)
    # Separate test cases with a blank line
    print()
+
1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+
vector<pair<int, int> > edges;
+vector<bool> connected;
+vector<int> queries;
+vector<int> answers;
+
+int main() {
+
+ int tests;
+ cin >> tests;
+
+ for (int t=0; t<tests; t++) {
+ edges.clear();
+ connected.clear();
+ queries.clear();
+ answers.clear();
+ int n;
+ cin >> n;
+ for (int i=0; i<n-1; i++) {
+ int x, y;
+ cin >> x >> y;
+ // 0-index
+ edges.push_back({x-1, y-1});
+ }
+ connected.assign(n-1, true);
+ int q;
+ cin >> q;
+ for (int i=0; i<q; i++) {
+ string s;
+ cin >> s;
+ if (s == "Q") {
+ queries.push_back(-1);
+ } else {
+ int a;
+ cin >> a;
+ queries.push_back(a-1);
+ connected[a-1] = false;
+ }
+ }
+ UnionFind uf(n);
+ // Add all remaining roads
+ for (int i=0; i<n-1; i++) {b
+ if (connected[i]) {b
+ uf.merge(edges[i].first, edges[i].second);b
+ }b
+ }b
+ // First - calculate how many pairs of hostels are disconnected.
+ int current = 0;b
+ for (int i=0; i<n; i++)b
+ current += n - uf.size_component(i);b
+ current = current / 2;b
+ // Answering Time!
+ reverse(queries.begin(), queries.end());
+ for (auto qn: queries) {
+ if (qn == -1) {
+ answers.push_back(current);
+ } else {
+ current = current - uf.size_component(edges[qn].first) * uf.size_component(edges[qn].second);b
+ uf.merge(edges[qn].first, edges[qn].second);b
+ }
+ }
+ reverse(answers.begin(), answers.end());
+ for (auto a: answers) {
+ cout << a << endl;
+ }
+ cout << endl;
+ }
+
+ return 0;
+}
+
Related Problems
This post is licensed under GNU GPL V3 by the author. Comments powered by Disqus.
diff --git a/redirects.json b/redirects.json
new file mode 100644
index 0000000..2f3c5c1
--- /dev/null
+++ b/redirects.json
@@ -0,0 +1 @@
+{"/norobots/":"https://monashaps.github.io//404.html","/assets/":"https://monashaps.github.io//404.html","/posts/":"https://monashaps.github.io//404.html"}
\ No newline at end of file
diff --git a/robots.txt b/robots.txt
new file mode 100644
index 0000000..9a60a0b
--- /dev/null
+++ b/robots.txt
@@ -0,0 +1,5 @@
+User-agent: *
+
+Disallow: /norobots/
+
+Sitemap: https://monashaps.github.io//sitemap.xml
diff --git a/sitemap.xml b/sitemap.xml
new file mode 100644
index 0000000..f091779
--- /dev/null
+++ b/sitemap.xml
@@ -0,0 +1,72 @@
+
+
+
+https://monashaps.github.io//posts/dp/
+2021-12-13T11:01:10+11:00
+
+
+https://monashaps.github.io//posts/mod/
+2021-04-01T22:07:32+11:00
+
+
+https://monashaps.github.io//posts/factorization/
+2021-04-05T21:02:26+10:00
+
+
+https://monashaps.github.io//posts/lca/
+2021-12-13T11:01:10+11:00
+
+
+https://monashaps.github.io//posts/problems-21-s2-c1/
+2021-08-23T11:00:00+10:00
+
+
+https://monashaps.github.io//posts/uf/
+2021-12-27T13:53:13+11:00
+
+
+https://monashaps.github.io//posts/dsless-editorial/
+2023-12-28T23:02:40+11:00
+
+
+https://monashaps.github.io//categories/
+2023-12-28T23:03:47+11:00
+
+
+https://monashaps.github.io//tags/
+2023-12-28T23:03:47+11:00
+
+
+https://monashaps.github.io//archives/
+2023-12-28T23:03:47+11:00
+
+
+https://monashaps.github.io//about/
+2023-12-28T23:03:47+11:00
+
+
+https://monashaps.github.io//
+
+
+https://monashaps.github.io//tags/difficulty-2/
+
+
+https://monashaps.github.io//tags/difficulty-3/
+
+
+https://monashaps.github.io//categories/data-structures/
+
+
+https://monashaps.github.io//categories/math/
+
+
+https://monashaps.github.io//categories/trees/
+
+
+https://monashaps.github.io//categories/contests/
+
+
+https://monashaps.github.io//assets/img/comp_assets/MCPC_Editorial.pdf
+2023-12-28T23:03:29+11:00
+
+
diff --git a/sw.js b/sw.js
new file mode 100644
index 0000000..dc2c062
--- /dev/null
+++ b/sw.js
@@ -0,0 +1 @@
+self.importScripts('/assets/js/data/cache-list.js'); var cacheName = 'chirpy-20231228.2303'; function isExcluded(url) { const regex = /(^http(s)?|^\/)/; /* the regex for CORS url or relative url */ for (const rule of exclude) { if (!regex.test(url) || url.indexOf(rule) != -1) { return true; } } return false; } self.addEventListener('install', (e) => { self.skipWaiting(); e.waitUntil( caches.open(cacheName).then((cache) => { return cache.addAll(include); }) ); }); self.addEventListener('fetch', (e) => { e.respondWith( caches.match(e.request).then((r) => { /* console.log('[Service Worker] Fetching resource: ' + e.request.url); */ return r || fetch(e.request).then((response) => { return caches.open(cacheName).then((cache) => { if (!isExcluded(e.request.url)) { /* console.log('[Service Worker] Caching new resource: ' + e.request.url); */ cache.put(e.request, response.clone()); } return response; }); }); }) ); }); self.addEventListener('activate', (e) => { e.waitUntil( caches.keys().then((keyList) => { return Promise.all(keyList.map((key) => { if(key !== cacheName) { return caches.delete(key); } })); }) ); });
diff --git a/tags/difficulty-2/index.html b/tags/difficulty-2/index.html
new file mode 100644
index 0000000..d9bdb61
--- /dev/null
+++ b/tags/difficulty-2/index.html
@@ -0,0 +1 @@
+ Difficulty 2 | Monash Code Binder Difficulty 2 2
- Union Find / DSU Dec 15, 2021
- Dynamic Programming Mar 26, 2021
diff --git a/tags/difficulty-3/index.html b/tags/difficulty-3/index.html
new file mode 100644
index 0000000..a5607db
--- /dev/null
+++ b/tags/difficulty-3/index.html
@@ -0,0 +1 @@
+ Difficulty 3 | Monash Code Binder Difficulty 3 3
- Least Common Ancestor (LCA) Apr 20, 2021
- Primes and Factorization Techniques Apr 5, 2021
- Modular Arithmetic Mar 29, 2021
diff --git a/tags/index.html b/tags/index.html
new file mode 100644
index 0000000..1ff22e2
--- /dev/null
+++ b/tags/index.html
@@ -0,0 +1 @@
+ Tags | Monash Code Binder Tags