Skip to content

Commit 495cff8

Browse files
fix: added path compression to UnionFind (#734)
* fix: implemented path compression in Find & removed unnecessary return value of Union * feat: added a few test cases * fix: modified kruskal implementation to conform to the updated Union method * fix: changed to pointer receivers --------- Co-authored-by: Rak Laptudirm <[email protected]>
1 parent 67eebcb commit 495cff8

File tree

3 files changed

+41
-29
lines changed

3 files changed

+41
-29
lines changed

graph/kruskal.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func KruskalMST(n int, edges []Edge) ([]Edge, int) {
4242
// Add the weight of the edge to the total cost
4343
cost += edge.Weight
4444
// Merge the sets containing the start and end vertices of the current edge
45-
u = u.Union(int(edge.Start), int(edge.End))
45+
u.Union(int(edge.Start), int(edge.End))
4646
}
4747
}
4848

graph/unionfind.go

+24-21
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
// is used to efficiently maintain connected components in a graph that undergoes dynamic changes,
44
// such as edges being added or removed over time
55
// Worst Case Time Complexity: The time complexity of find operation is nearly constant or
6-
//O(α(n)), where where α(n) is the inverse Ackermann function
6+
//O(α(n)), where α(n) is the inverse Ackermann function
77
// practically, this is a very slowly growing function making the time complexity for find
88
//operation nearly constant.
99
// The time complexity of the union operation is also nearly constant or O(α(n))
1010
// Worst Case Space Complexity: O(n), where n is the number of nodes or element in the structure
1111
// Reference: https://www.scaler.com/topics/data-structures/disjoint-set/
12+
// https://en.wikipedia.org/wiki/Disjoint-set_data_structure
1213
// Author: Mugdha Behere[https://github.com/MugdhaBehere]
1314
// see: unionfind.go, unionfind_test.go
1415

@@ -17,43 +18,45 @@ package graph
1718
// Defining the union-find data structure
1819
type UnionFind struct {
1920
parent []int
20-
size []int
21+
rank []int
2122
}
2223

2324
// Initialise a new union find data structure with s nodes
2425
func NewUnionFind(s int) UnionFind {
2526
parent := make([]int, s)
26-
size := make([]int, s)
27-
for k := 0; k < s; k++ {
28-
parent[k] = k
29-
size[k] = 1
27+
rank := make([]int, s)
28+
for i := 0; i < s; i++ {
29+
parent[i] = i
30+
rank[i] = 1
3031
}
31-
return UnionFind{parent, size}
32+
return UnionFind{parent, rank}
3233
}
3334

34-
// to find the root of the set to which the given element belongs, the Find function serves the purpose
35-
func (u UnionFind) Find(q int) int {
36-
for q != u.parent[q] {
37-
q = u.parent[q]
35+
// Find finds the root of the set to which the given element belongs.
36+
// It performs path compression to make future Find operations faster.
37+
func (u *UnionFind) Find(q int) int {
38+
if q != u.parent[q] {
39+
u.parent[q] = u.Find(u.parent[q])
3840
}
39-
return q
41+
return u.parent[q]
4042
}
4143

42-
// to merge two sets to which the given elements belong, the Union function serves the purpose
43-
func (u UnionFind) Union(a, b int) UnionFind {
44-
rootP := u.Find(a)
45-
rootQ := u.Find(b)
44+
// Union merges the sets, if not already merged, to which the given elements belong.
45+
// It performs union by rank to keep the tree as flat as possible.
46+
func (u *UnionFind) Union(p, q int) {
47+
rootP := u.Find(p)
48+
rootQ := u.Find(q)
4649

4750
if rootP == rootQ {
48-
return u
51+
return
4952
}
5053

51-
if u.size[rootP] < u.size[rootQ] {
54+
if u.rank[rootP] < u.rank[rootQ] {
5255
u.parent[rootP] = rootQ
53-
u.size[rootQ] += u.size[rootP]
56+
} else if u.rank[rootP] > u.rank[rootQ] {
57+
u.parent[rootQ] = rootP
5458
} else {
5559
u.parent[rootQ] = rootP
56-
u.size[rootP] += u.size[rootQ]
60+
u.rank[rootP]++
5761
}
58-
return u
5962
}

graph/unionfind_test.go

+16-7
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ func TestUnionFind(t *testing.T) {
88
u := NewUnionFind(10) // Creating a Union-Find data structure with 10 elements
99

1010
//union operations
11-
u = u.Union(0, 1)
12-
u = u.Union(2, 3)
13-
u = u.Union(4, 5)
14-
u = u.Union(6, 7)
11+
u.Union(0, 1)
12+
u.Union(2, 3)
13+
u.Union(4, 5)
14+
u.Union(6, 7)
1515

1616
// Testing the parent of specific elements
1717
t.Run("Test Find", func(t *testing.T) {
@@ -20,12 +20,21 @@ func TestUnionFind(t *testing.T) {
2020
}
2121
})
2222

23-
u = u.Union(1, 5) // Additional union operation
24-
u = u.Union(3, 7) // Additional union operation
23+
u.Union(1, 5) // Additional union operation
24+
u.Union(3, 7) // Additional union operation
2525

2626
// Testing the parent of specific elements after more union operations
2727
t.Run("Test Find after Union", func(t *testing.T) {
28-
if u.Find(0) != u.Find(5) || u.Find(2) != u.Find(7) {
28+
if u.Find(0) != u.Find(5) || u.Find(1) != u.Find(4) || u.Find(2) != u.Find(7) || u.Find(3) != u.Find(6) {
29+
t.Error("Union operation not functioning correctly")
30+
}
31+
})
32+
33+
u.Union(3, 7) // Repeated union operation
34+
35+
// Testing that repeated union operations are idempotent
36+
t.Run("Test Find after repeated Union", func(t *testing.T) {
37+
if u.Find(2) != u.Find(6) || u.Find(2) != u.Find(7) || u.Find(3) != u.Find(6) || u.Find(3) != u.Find(7) {
2938
t.Error("Union operation not functioning correctly")
3039
}
3140
})

0 commit comments

Comments
 (0)