-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmxor_solver.h
139 lines (123 loc) · 3.58 KB
/
mxor_solver.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/*
* mxor-solver header.
* written by Shuangquan Li, [email protected]
* created on 2016-11-2
*/
#ifndef __MXOR_SOLVER_H__
#define __MXOR_SOLVER_H__
#include <cassert>
#include <cstddef>
#include <memory>
// mxor_solver: a multiset of integers stored in a binary trie, designed to
// answer max/min xor problems such as:
//   - given an integer array A, find the max/min xor of any two integers in A;
//   - given an integer array A, count the pairs whose xor is greater than
//     some given integer X.
//
// Template parameters:
//   IntegerType  value type; must support `>>`, `&`, `<<`, `|=` and
//                construction from 0 and 1.
//   MAX_BITS     number of low-order bits of each value that are stored. The
//                trie has one level per bit, most significant stored bit first,
//                so every operation below walks at most MAX_BITS nodes.
//                For built-in integer types MAX_BITS must not exceed the width
//                of IntegerType (shifting by the full width is undefined).
//
// Copying a solver makes an independent deep copy of the trie. A moved-from
// solver is left empty and remains fully usable.
template<typename IntegerType, std::size_t MAX_BITS> class mxor_solver {
    struct Node {
        std::unique_ptr<Node> child[2]; // child[b] is the subtree for next bit == b
        std::size_t dup;                // number of stored elements passing through this node
        Node() : dup(0) {}
    };

    // Root of the trie. A null root means "no elements"; it is allocated
    // lazily by insert() so that default construction and moves never allocate.
    std::unique_ptr<Node> head;

    // Index of the most significant stored bit (-1 when MAX_BITS == 0).
    static int top_bit() { return static_cast<int>(MAX_BITS) - 1; }

    // Bit `i` of `v`, as 0 or 1.
    static int bit_at(const IntegerType& v, int i) {
        return static_cast<int>((v >> i) & 1);
    }

    // Recursively duplicate the subtree rooted at `src` (null stays null).
    // Recursion depth is bounded by MAX_BITS + 1.
    static std::unique_ptr<Node> clone(const Node* src) {
        if (!src) return std::unique_ptr<Node>();
        std::unique_ptr<Node> dst(new Node);
        dst->dup = src->dup;
        dst->child[0] = clone(src->child[0].get());
        dst->child[1] = clone(src->child[1].get());
        return dst;
    }

public:
    mxor_solver() = default;
    ~mxor_solver() = default;

    // Deep copy: the new solver owns its own trie, so later modifications of
    // either object are invisible to the other.
    mxor_solver(const mxor_solver& other) : head(clone(other.head.get())) {}

    mxor_solver& operator=(const mxor_solver& other) {
        if (this != &other) {
            // The copy is fully built before the old trie is released.
            head = clone(other.head.get());
        }
        return *this;
    }

    // Moves transfer the trie and leave the source empty.
    mxor_solver(mxor_solver&&) = default;
    mxor_solver& operator=(mxor_solver&&) = default;

    // Number of elements currently stored (duplicates counted).
    std::size_t size() const { return head ? head->dup : 0; }

    // True when no elements are stored.
    bool empty() const { return size() == 0; }

    // Remove all elements.
    void clear() { head.reset(); }

    // Insert one occurrence of `val`.
    void insert(const IntegerType& val) {
        if (!head) head.reset(new Node);
        Node* cur = head.get();
        ++cur->dup;
        for (int i = top_bit(); i >= 0; --i) {
            std::unique_ptr<Node>& slot = cur->child[bit_at(val, i)];
            if (!slot) slot.reset(new Node);
            cur = slot.get();
            ++cur->dup;
        }
    }

    // Erase one occurrence of `val`. Does nothing when `val` is not stored
    // (same contract as erasing a missing key from a standard container).
    void erase(const IntegerType& val) {
        // Checking first guarantees every node on the path exists, so the
        // walk below can never dereference a missing child or underflow a count.
        if (count(val) == 0) return;
        Node* cur = head.get();
        --cur->dup;
        for (int i = top_bit(); i >= 0; --i) {
            std::unique_ptr<Node>& next = cur->child[bit_at(val, i)];
            if (--next->dup == 0) {
                // Nothing else lives below this node: drop the whole branch.
                next.reset();
                return;
            }
            cur = next.get();
        }
    }

    // Number of stored occurrences of `val`.
    std::size_t count(const IntegerType& val) const {
        const Node* cur = head.get();
        for (int i = top_bit(); i >= 0; --i) {
            if (!cur) return 0;
            cur = cur->child[bit_at(val, i)].get();
        }
        return cur ? cur->dup : 0;
    }

    // Maximum of (e xor xorWithWhom) over all stored elements e.
    // Precondition: the solver is not empty (checked with assert).
    IntegerType query_max(const IntegerType& xorWithWhom = 0) const {
        assert(!empty());
        IntegerType best = 0;
        const Node* cur = head.get();
        for (int i = top_bit(); i >= 0; --i) {
            const int b = bit_at(xorWithWhom, i);
            if (cur->child[b ^ 1]) {
                // An element with the opposite bit exists: result bit i is 1.
                best |= IntegerType(1) << i;
                cur = cur->child[b ^ 1].get();
            } else {
                cur = cur->child[b].get();
            }
        }
        return best;
    }

    // Minimum of (e xor xorWithWhom) over all stored elements e.
    // Precondition: the solver is not empty (checked with assert).
    IntegerType query_min(const IntegerType& xorWithWhom = 0) const {
        assert(!empty());
        IntegerType best = 0;
        const Node* cur = head.get();
        for (int i = top_bit(); i >= 0; --i) {
            const int b = bit_at(xorWithWhom, i);
            if (cur->child[b]) {
                // An element with the same bit exists: result bit i stays 0.
                cur = cur->child[b].get();
            } else {
                best |= IntegerType(1) << i;
                cur = cur->child[b ^ 1].get();
            }
        }
        return best;
    }

    // Number of stored elements e with (e xor xorWithWhom) > thanWhom.
    size_t query_gt(const IntegerType& xorWithWhom, const IntegerType& thanWhom) const {
        if (empty()) return 0;
        std::size_t total = 0;
        const Node* cur = head.get();
        for (int i = top_bit(); i >= 0 && cur; --i) {
            const int w = bit_at(xorWithWhom, i);
            if (bit_at(thanWhom, i)) {
                // Threshold bit is 1: only elements whose xor bit is also 1
                // can still tie; everything else is already smaller.
                cur = cur->child[w ^ 1].get();
            } else {
                // Threshold bit is 0: every element whose xor bit is 1 is
                // strictly greater; elements with xor bit 0 are still tied.
                if (cur->child[w ^ 1]) total += cur->child[w ^ 1]->dup;
                cur = cur->child[w].get();
            }
        }
        // Elements that tie on every bit are equal to thanWhom, not greater.
        return total;
    }
};
/* eof */
#endif