-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkittys-calculations-on-a-tree.cpp
111 lines (89 loc) · 2.81 KB
/
kittys-calculations-on-a-tree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include <iostream>
#include <vector>
#include <algorithm>
#include <utility>
using namespace std;
// Undirected tree as adjacency lists: g[i] holds the 0-based neighbours of node i.
using Graph = std::vector<std::vector<int>>;
// Every pairwise sum is reported modulo this prime.
constexpr long mod = 1000000007;
// Class to record and manipulate values for each node
// Accumulator for one tree node `x` in the bottom-up merge driven by Code.
// Over the set S of queried node ids merged into this record so far:
//   s    = sum over unordered pairs {u, v} of S of u * v * dist(u, v)  (mod `mod`)
//   su   = sum over u in S of u                                        (mod `mod`)
//   sudu = sum over u in S of u * dist(u, x)                           (mod `mod`)
// The fields are `long long` (guaranteed to be at least 64 bits). add()
// multiplies two reduced values, giving intermediates up to ~3e18. That would
// overflow a 32-bit `long`, as on LLP64 platforms such as 64-bit Windows.
class Record {
    long long s, su, sudu;
public:
    // Empty record: no queried node merged yet.
    Record(): s(0), su(0), sudu(0) { }
    // Record for a queried node whose id is `id`, sitting at `x` itself
    // (distance 0, so no pairs and no distance-weighted sum yet).
    Record(int id): s(0), su(id), sudu(0) { }
    // Current pairwise sum. add() keeps it below `mod`, so it fits in an int.
    int get_s() const { return static_cast<int>(s); }
    // Merges a child's record `b` into this (parent) record.
    // Every node counted in `b` is one edge farther from the parent than from
    // the child, hence the extra `b.su` terms. A pair with u on this side and
    // v in `b` contributes u*v*(dist(u,x) + dist(v,child) + 1). Summed over all
    // such pairs, that is sudu*b.su + su*(b.sudu + b.su).
    inline void add(const Record &b) {
        long long ns = s + b.s
                     + sudu * b.su
                     + su * (b.sudu + b.su);
        if (ns >= mod) ns %= mod;
        long long nsu = su + b.su;
        if (nsu >= mod) nsu %= mod;
        // Shift b's distances from the child to the parent (+1 per node).
        long long nsudu = sudu + b.sudu + b.su;
        if (nsudu >= mod) nsudu %= mod;
        s = ns;
        su = nsu;
        sudu = nsudu;
    }
    // True when every field is zero. Merging such a record is a no-op, since
    // add() leaves all fields unchanged, so callers may skip it.
    inline bool zero() const {
        return s == 0 && su == 0 && sudu == 0;
    }
};
// Class to generate and execute the code
class Code {
vector<pair<int, int>> code; // Code representation
// Recursive function to generate the code
void gen_code(Graph &g, vector<bool> &color, int i) {
color[i] = true;
for (int child : g[i]) {
if (color[child]) continue;
gen_code(g, color, child);
code.push_back({ child, i });
}
}
public:
// Constructor to generate the code from the graph
Code(Graph &g) {
vector<bool> color(g.size());
gen_code(g, color, 0); // Start with the root node (0)
}
// Function to execute the code and return the result
int exec(vector<Record> &data) {
int last = 0;
for (auto c : code) {
last = c.second;
if (!data[c.first].zero())
data[c.second].add(data[c.first]);
}
return data[last].get_s(); // Return the sum from the last node
}
};
// Reads a tree of n nodes (1-based ids, n-1 edges) followed by q queries.
// Each query lists m node ids. For each query it prints the value that
// Code::exec accumulates at the root: the sum of u*v*dist(u,v) over pairs of
// queried nodes, modulo `mod`.
int main() {
    // One line is printed per query. Avoid per-line flushes (no `endl`) and
    // stdio synchronisation overhead.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 0, q = 0;
    // Bail out cleanly on a missing/invalid header instead of using
    // uninitialized sizes or indexing an empty graph.
    if (!(cin >> n >> q) || n <= 0) return 0;

    Graph g(n);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        // Undirected edge: record it in both adjacency lists (0-based).
        g[a - 1].push_back(b - 1);
        g[b - 1].push_back(a - 1);
    }
    Code code(g);             // Compile the tree once; replay it per query
    vector<Record> data(n);   // One accumulator per node, reused across queries
    while (q--) {
        fill(data.begin(), data.end(), Record());  // Reset all accumulators
        int m;
        cin >> m;
        while (m--) {
            int a;
            cin >> a;
            data[a - 1] = Record(a);  // Queried node contributes its own id
        }
        cout << code.exec(data) << '\n';
    }
    return 0;
}