Skip to content

Commit b71b387

Browse files
Merge pull request #9761 from IoannisPanagiotas/pcst-initial-wrapup
Pcst initial wrapup
2 parents c18af37 + 2fc8e4e commit b71b387

File tree

12 files changed

+187
-78
lines changed

12 files changed

+187
-78
lines changed

algo/src/main/java/org/neo4j/gds/pricesteiner/ClusterActivity.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ long numberOfActiveClusters() {
6161
return relevantTime.get(clusterId);
6262
}
6363

64-
6564
LongPredicate active() {
6665
return activeClusters::get;
6766
}

algo/src/main/java/org/neo4j/gds/pricesteiner/ClusterEventsPriorityQueue.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ void pop(){
4848
queue.pop();
4949
}
5050

51-
void add(long cluster, double sumOfPrizes){
52-
queue.add(cluster, sumOfPrizes);
51+
void add(long cluster, double remainingMoat){
52+
queue.add(cluster, remainingMoat);
5353
}
5454

5555
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.pricesteiner;
21+
22+
class ClusterMoatPair {
23+
24+
private long cluster;
25+
private double totalMoat;
26+
27+
void assign(long cluster, double totalMoat){
28+
this.cluster = cluster;
29+
this.totalMoat = totalMoat;
30+
}
31+
32+
long cluster(){
33+
return cluster;
34+
}
35+
36+
double totalMoat(){
37+
return totalMoat;
38+
}
39+
}

algo/src/main/java/org/neo4j/gds/pricesteiner/ClusterStructure.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ void setClusterPrize(long clusterId, double prize){
8181
initialMoatLeft.set(clusterId, prize);
8282
}
8383

84-
double clusterPrize(long clusterId){
84+
double moatLeft(long clusterId){
8585
return initialMoatLeft.get(clusterId);
8686
}
8787

@@ -90,16 +90,16 @@ void setClusterPrize(long clusterId, double prize){
9090
return currentMoat + slack;
9191
}
9292

93-
ClusterMoatPair sumOnEdgePart(long node, double currentMoat){
93+
void sumOnEdgePart(long node, double currentMoat, ClusterMoatPair clusterMoatPair){
9494
double sum = 0;
95-
long currentNode =node;
95+
long currentNode = node;
9696

9797
while (true){
9898

9999
var parentNode = parent.get(currentNode);
100100
double currentValue = moatAt(currentNode,currentMoat);
101101

102-
sum+= currentValue;
102+
sum+= currentValue;
103103
if (parentNode== -1){
104104
break;
105105
}else{
@@ -113,24 +113,21 @@ ClusterMoatPair sumOnEdgePart(long node, double currentMoat){
113113
}
114114
sum += skippedParentSum.get(currentNode);
115115
currentNode = nextParent;
116-
117-
118116
}
119117

120118
}
121-
122-
return new ClusterMoatPair(currentNode,sum);
119+
clusterMoatPair.assign(currentNode,sum);
123120
}
124121

125122
BitSet activeOriginalNodesOfCluster(long clusterId){
126-
BitSet bitSet=new BitSet(originalNodeCount);
123+
BitSet bitSet = new BitSet(originalNodeCount);
127124

128125
if (clusterId < originalNodeCount){
129126
bitSet.set(clusterId);
130127
return bitSet;
131128
}
132129

133-
HugeLongArrayStack stack= HugeLongArrayStack.newStack(originalNodeCount);
130+
HugeLongArrayStack stack = HugeLongArrayStack.newStack(originalNodeCount);
134131
stack.push(clusterId);
135132

136133
while (!stack.isEmpty()){
@@ -183,7 +180,6 @@ long singleActiveCluster(){
183180

184181

185182
}
186-
record ClusterMoatPair(long cluster, double totalMoat){}
187183

188184

189185

algo/src/main/java/org/neo4j/gds/pricesteiner/EdgeEventsQueue.java

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,21 @@ class EdgeEventsQueue {
2727

2828
private final HugeObjectArray<PairingHeap> pairingHeaps;
2929
private final HugeLongPriorityQueue edgeEventsPriorityQueue;
30-
private final ObjectArrayList<PairingHeapElement> helpingArray;
31-
private long currentlyActive;
3230

3331
EdgeEventsQueue(long nodeCount){
3432

3533
this.pairingHeaps= HugeObjectArray.newArray(PairingHeap.class, 2*nodeCount);
3634
this.edgeEventsPriorityQueue = HugeLongPriorityQueue.min(2*nodeCount);
37-
this.helpingArray= new ObjectArrayList<>(4096);
35+
ObjectArrayList<PairingHeapElement> helpingArray = new ObjectArrayList<>(4096);
36+
3837
for (int i=0;i<nodeCount;i++){
3938
pairingHeaps.set(i, new PairingHeap(helpingArray));
4039
}
41-
currentlyActive = nodeCount;
42-
}
43-
long currentlyActive(){
44-
return currentlyActive;
40+
4541
}
42+
4643
double nextEventTime(){
47-
return edgeEventsPriorityQueue.cost(edgeEventsPriorityQueue.top());
44+
return edgeEventsPriorityQueue.cost(edgeEventsPriorityQueue.top());
4845
}
4946

5047
long top(){
@@ -80,6 +77,7 @@ void addWithCheck(long s, long edgePart, double w){
8077
edgeEventsPriorityQueue.set(s, w);
8178
}
8279
}
80+
8381
void addWithoutCheck(long s, long edgePart, double w){
8482
var pairingHeapOfs = pairingHeaps.get(s);
8583
pairingHeapOfs.add(edgePart,w);
@@ -97,13 +95,11 @@ void mergeAndUpdate(long newCluster, long cluster1,long cluster2){
9795
deactivateCluster(cluster2);
9896

9997
edgeEventsPriorityQueue.add(newCluster, pairingHeaps.get(newCluster).minValue());
100-
currentlyActive--;
10198
}
10299

103100
void deactivateCluster(long clusterId){
104101
// very-very bad way of removing from common heap-of-heaps
105102
edgeEventsPriorityQueue.set(clusterId,Double.MAX_VALUE); //ditto
106-
107103
}
108104

109105
void performInitialAssignment(long nodeCount){

algo/src/main/java/org/neo4j/gds/pricesteiner/GrowthPhase.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ GrowthResult grow() {
7070

7171
progressTracker.beginSubTask("Growing");
7272
double moat;
73+
74+
ClusterMoatPair clusterMoatPairOfu = new ClusterMoatPair();
75+
ClusterMoatPair clusterMoatPairOfv = new ClusterMoatPair();
76+
7377
while (clusterStructure.numberOfActiveClusters() > 1) {
7478
terminationFlag.assertRunning();
7579
double edgeEventTime = edgeEventsQueue.nextEventTime();
@@ -92,14 +96,14 @@ GrowthResult grow() {
9296
continue;
9397
}
9498

95-
ClusterMoatPair cmu = clusterStructure.sumOnEdgePart(u,moat);
96-
ClusterMoatPair cmv = clusterStructure.sumOnEdgePart(v,moat);
99+
clusterStructure.sumOnEdgePart(u,moat,clusterMoatPairOfu);
100+
clusterStructure.sumOnEdgePart(v,moat,clusterMoatPairOfv);
97101

98-
var uCluster = cmu.cluster();
99-
var uClusterSum = cmu.totalMoat();
102+
var uCluster = clusterMoatPairOfu.cluster();
103+
var uClusterSum = clusterMoatPairOfu.totalMoat();
100104

101-
var vCluster = cmv.cluster();
102-
var vClusterSum = cmv.totalMoat();
105+
var vCluster = clusterMoatPairOfv.cluster();
106+
var vClusterSum = clusterMoatPairOfv.totalMoat();
103107

104108

105109
if (vCluster == uCluster) {
@@ -160,9 +164,10 @@ private void initializeEdgeParts() {
160164
edgeCosts.set(edgeId, w);
161165
edgeParts.set(2 * edgeId, s);
162166
edgeParts.set(2 * edgeId + 1, t);
163-
edgeEventsQueue.addBothWays(s, t, edgePart1, edgePart2, w / 2);
167+
edgeEventsQueue.addBothWays(s, t, edgePart1, edgePart2, w / 2.0);
168+
return true;
164169
}
165-
return s > t;
170+
return false;
166171
});
167172
progressTracker.logProgress(graph.degree(u));
168173
}
@@ -173,7 +178,7 @@ private void initializeClusterPrizes() {
173178
for (long u = 0; u < graph.nodeCount(); ++u) {
174179
double prize = prizes.applyAsDouble(u);
175180
clusterStructure.setClusterPrize(u, prize);
176-
clusterEventsPriorityQueue.add(u, clusterStructure.tightnessTime(u, 0));
181+
clusterEventsPriorityQueue.add(u, prize);
177182
}
178183
}
179184

@@ -193,13 +198,13 @@ private void mergeClusters(
193198
edgeEventsQueue.increaseValuesOnInactiveCluster(cluster2, moat - clusterStructure.inactiveSince(cluster2));
194199
}
195200

196-
var newCluster = clusterStructure.merge(cluster1, cluster2,moat);
201+
var newCluster = clusterStructure.merge(cluster1, cluster2, moat);
197202

198203
edgeEventsQueue.mergeAndUpdate(newCluster, cluster1, cluster2);
199204
clusterEventsPriorityQueue.add(newCluster, clusterStructure.tightnessTime(newCluster, moat));
200205

201206
addToTree(edgeId);
202-
edgeParts.set(2*edgeId,-edgeParts.get(2*edgeId));
207+
edgeParts.set(2*edgeId, -edgeParts.get(2*edgeId)); //signal that edge id has been used
203208
edgeParts.set(2*edgeId+1,-edgeParts.get(2*edgeId+1));
204209
}
205210

algo/src/main/java/org/neo4j/gds/pricesteiner/StrongPruning.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ void performPruning(){
7979
}
8080
}
8181

82-
8382
while (currentPos < totalPos) {
8483
terminationFlag.assertRunning();;
8584
var nextLeaf = queue.get(currentPos++);
@@ -145,8 +144,8 @@ PrizeSteinerTreeResult resultTree(){
145144
private void pruneSubtree(long node, HugeLongArray helpingArray,HugeLongArray parents){
146145
terminationFlag.assertRunning();
147146
var tree = treeStructure.tree();
148-
long currentPosition= 0;
149-
MutableLong position=new MutableLong();
147+
long currentPosition = 0;
148+
MutableLong position = new MutableLong();
150149
helpingArray.set(position.getAndIncrement(),node);
151150

152151
while (currentPosition < position.get()){

algo/src/main/java/org/neo4j/gds/pricesteiner/TreeProducer.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.neo4j.gds.api.Graph;
2525
import org.neo4j.gds.api.IdMap;
2626
import org.neo4j.gds.collections.ha.HugeLongArray;
27-
import org.neo4j.gds.core.Aggregation;
2827
import org.neo4j.gds.core.loading.construction.GraphFactory;
2928
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
3029
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
@@ -49,11 +48,10 @@ static TreeStructure createTree(GrowthResult growthResult,long nodeCount, IdMap
4948

5049
RelationshipsBuilder relationshipsBuilder = GraphFactory.initRelationshipsBuilder()
5150
.nodes(idMap)
52-
.relationshipType(RelationshipType.of("_IGNORED_"))
51+
.relationshipType(RelationshipType.of("TREE"))
5352
.orientation(Orientation.UNDIRECTED)
5453
.addPropertyConfig(GraphFactory.PropertyConfig.builder()
55-
.propertyKey("property")
56-
.aggregation(Aggregation.SUM)
54+
.propertyKey("weight")
5755
.build())
5856
.build();
5957

@@ -71,11 +69,9 @@ static TreeStructure createTree(GrowthResult growthResult,long nodeCount, IdMap
7169
var singleTypeRelationships= relationshipsBuilder.build();
7270
var tree = GraphFactory.create(idMap, singleTypeRelationships);
7371

74-
7572
progressTracker.endSubTask("Tree Creation");
7673
return new TreeStructure(tree,degree, idMap.nodeCount());
7774

78-
7975
}
8076

8177
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.pricesteiner;
21+
22+
import org.junit.jupiter.api.Test;
23+
24+
import static org.assertj.core.api.Assertions.assertThat;
25+
26+
class ClusterActivityTest {
27+
28+
@Test
29+
void shouldWorkAsExpected(){
30+
var clusterActivity =new ClusterActivity(10);
31+
assertThat(clusterActivity.relevantTime(0)).isEqualTo(0);
32+
assertThat(clusterActivity.numberOfActiveClusters()).isEqualTo(10);
33+
34+
clusterActivity.deactivateCluster(0,5);
35+
assertThat(clusterActivity.relevantTime(0)).isEqualTo(5);
36+
assertThat(clusterActivity.numberOfActiveClusters()).isEqualTo(9);
37+
38+
clusterActivity.deactivateCluster(1,5);
39+
assertThat(clusterActivity.numberOfActiveClusters()).isEqualTo(8);
40+
41+
clusterActivity.activateCluster(11,5);
42+
assertThat(clusterActivity.relevantTime(11)).isEqualTo(5);
43+
assertThat(clusterActivity.numberOfActiveClusters()).isEqualTo(9);
44+
45+
assertThat(clusterActivity.active(0)).isFalse();
46+
assertThat(clusterActivity.active(1)).isFalse();
47+
assertThat(clusterActivity.active(2)).isTrue();
48+
assertThat(clusterActivity.active(11)).isTrue();
49+
50+
clusterActivity.deactivateCluster(11,15);
51+
assertThat(clusterActivity.numberOfActiveClusters()).isEqualTo(8);
52+
53+
assertThat(clusterActivity.relevantTime(11)).isEqualTo(15);
54+
}
55+
56+
@Test
57+
void shouldFindSingleActiveNodeCorrectly(){
58+
var clusterActivity =new ClusterActivity(4);
59+
clusterActivity.deactivateCluster(0,1);
60+
clusterActivity.deactivateCluster(1,1);
61+
clusterActivity.deactivateCluster(3,1);
62+
63+
assertThat(clusterActivity.firstActiveCluster()).isEqualTo(2);
64+
clusterActivity.activateCluster(0,1); //this does not happen in practice but oh well it is a test
65+
clusterActivity.deactivateCluster(2,1);
66+
assertThat(clusterActivity.firstActiveCluster()).isEqualTo(0);
67+
68+
}
69+
70+
}

0 commit comments

Comments
 (0)