Skip to content

Commit 2c19e8c

Browse files
authored
Termination flag for sampling algorithms - abstract NodeSampler (#9724)
Support termination flag in sampling algorithms
1 parent f4f3d01 commit 2c19e8c

File tree

17 files changed

+131
-25
lines changed

17 files changed

+131
-25
lines changed

applications/graph-store-catalog/src/main/java/org/neo4j/gds/applications/graphstorecatalog/DefaultGraphCatalogApplications.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,7 @@ public RandomWalkSamplingResult sampleRandomWalkWithRestarts(
963963
DatabaseId databaseId,
964964
TaskRegistryFactory taskRegistryFactory,
965965
UserLogRegistryFactory userLogRegistryFactory,
966+
TerminationFlag terminationFlag,
966967
String graphName,
967968
String originGraphName,
968969
Map<String, Object> configuration
@@ -972,6 +973,7 @@ public RandomWalkSamplingResult sampleRandomWalkWithRestarts(
972973
databaseId,
973974
taskRegistryFactory,
974975
userLogRegistryFactory,
976+
terminationFlag,
975977
graphName,
976978
originGraphName,
977979
configuration,
@@ -985,6 +987,7 @@ public RandomWalkSamplingResult sampleCommonNeighbourAwareRandomWalk(
985987
DatabaseId databaseId,
986988
TaskRegistryFactory taskRegistryFactory,
987989
UserLogRegistryFactory userLogRegistryFactory,
990+
TerminationFlag terminationFlag,
988991
String graphNameAsString,
989992
String originGraphName,
990993
Map<String, Object> configuration
@@ -994,6 +997,7 @@ public RandomWalkSamplingResult sampleCommonNeighbourAwareRandomWalk(
994997
databaseId,
995998
taskRegistryFactory,
996999
userLogRegistryFactory,
1000+
terminationFlag,
9971001
graphNameAsString,
9981002
originGraphName,
9991003
configuration,
@@ -1104,6 +1108,7 @@ private RandomWalkSamplingResult sampleRandomWalk(
11041108
DatabaseId databaseId,
11051109
TaskRegistryFactory taskRegistryFactory,
11061110
UserLogRegistryFactory userLogRegistryFactory,
1111+
TerminationFlag terminationFlag,
11071112
String graphNameAsString,
11081113
String originGraphNameAsString,
11091114
Map<String, Object> configuration,
@@ -1125,6 +1130,7 @@ private RandomWalkSamplingResult sampleRandomWalk(
11251130
userLogRegistryFactory,
11261131
graphStore,
11271132
graphProjectConfig,
1133+
terminationFlag,
11281134
originGraphName,
11291135
graphName,
11301136
configuration,

applications/graph-store-catalog/src/main/java/org/neo4j/gds/applications/graphstorecatalog/GraphCatalogApplications.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ RandomWalkSamplingResult sampleRandomWalkWithRestarts(
255255
DatabaseId databaseId,
256256
TaskRegistryFactory taskRegistryFactory,
257257
UserLogRegistryFactory userLogRegistryFactory,
258+
TerminationFlag terminationFlag,
258259
String graphName,
259260
String originGraphName,
260261
Map<String, Object> configuration
@@ -265,6 +266,7 @@ RandomWalkSamplingResult sampleCommonNeighbourAwareRandomWalk(
265266
DatabaseId databaseId,
266267
TaskRegistryFactory taskRegistryFactory,
267268
UserLogRegistryFactory userLogRegistryFactory,
269+
TerminationFlag terminationFlag,
268270
String graphName,
269271
String originGraphName,
270272
Map<String, Object> configuration

applications/graph-store-catalog/src/main/java/org/neo4j/gds/applications/graphstorecatalog/GraphSamplingApplication.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.neo4j.gds.graphsampling.GraphSampleConstructor;
3333
import org.neo4j.gds.graphsampling.RandomWalkSamplerType;
3434
import org.neo4j.gds.logging.Log;
35+
import org.neo4j.gds.termination.TerminationFlag;
3536

3637
import java.util.Map;
3738

@@ -50,6 +51,7 @@ RandomWalkSamplingResult sample(
5051
UserLogRegistryFactory userLogRegistryFactory,
5152
GraphStore graphStore,
5253
GraphProjectConfig graphProjectConfig,
54+
TerminationFlag terminationFlag,
5355
GraphName originGraphName,
5456
GraphName graphName,
5557
Map<String, Object> configuration,
@@ -73,7 +75,8 @@ RandomWalkSamplingResult sample(
7375
samplerConfig,
7476
graphStore,
7577
samplerAlgorithm,
76-
progressTracker
78+
progressTracker,
79+
terminationFlag
7780
);
7881
var sampledGraphStore = graphSampleConstructor.compute();
7982

applications/graph-store-catalog/src/test/java/org/neo4j/gds/applications/graphstorecatalog/GraphSamplingApplicationTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.neo4j.gds.extension.IdFunction;
3939
import org.neo4j.gds.extension.Inject;
4040
import org.neo4j.gds.logging.Log;
41+
import org.neo4j.gds.termination.TerminationFlag;
4142

4243
import java.util.List;
4344
import java.util.Map;
@@ -118,6 +119,7 @@ void shouldSampleRWR(Map<String, Object> mapConfiguration, long expectedNodeCoun
118119
EmptyUserLogRegistryFactory.INSTANCE,
119120
graphStore,
120121
GraphProjectConfig.emptyWithName("user", "graph"),
122+
TerminationFlag.RUNNING_TRUE,
121123
GraphName.parse("graph"),
122124
GraphName.parse("sample"),
123125
mapConfiguration,
@@ -157,6 +159,7 @@ void shouldSampleCNARW(Map<String, Object> mapConfiguration, long expectedNodeCo
157159
EmptyUserLogRegistryFactory.INSTANCE,
158160
graphStore,
159161
GraphProjectConfig.emptyWithName("user", "graph"),
162+
TerminationFlag.RUNNING_TRUE,
160163
GraphName.parse("graph"),
161164
GraphName.parse("sample"),
162165
mapConfiguration,
@@ -196,6 +199,7 @@ void shouldUseSingleStartNodeRWR(double samplingRatio, long expectedStartNodeCou
196199
EmptyUserLogRegistryFactory.INSTANCE,
197200
graphStore,
198201
GraphProjectConfig.emptyWithName("user", "graph"),
202+
TerminationFlag.RUNNING_TRUE,
199203
GraphName.parse("graph"),
200204
GraphName.parse("sample"),
201205
Map.of(
@@ -239,6 +243,7 @@ void shouldUseSingleStartNodeCNARW(double samplingRatio, long expectedStartNodeC
239243
EmptyUserLogRegistryFactory.INSTANCE,
240244
graphStore,
241245
GraphProjectConfig.emptyWithName("user", "graph"),
246+
TerminationFlag.RUNNING_TRUE,
242247
GraphName.parse("graph"),
243248
GraphName.parse("sample"),
244249
Map.of(

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/GraphSampleConstructor.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
4646
import org.neo4j.gds.core.utils.progress.tasks.Task;
4747
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
48+
import org.neo4j.gds.termination.TerminationFlag;
4849

4950
import java.util.List;
5051
import java.util.Map;
@@ -57,18 +58,21 @@ public class GraphSampleConstructor {
5758
private final GraphStore inputGraphStore;
5859
private final NodesSampler nodesSampler;
5960
private final ProgressTracker progressTracker;
61+
private final TerminationFlag terminationFlag;
6062

6163
public GraphSampleConstructor(
6264
GraphSampleAlgoConfig config,
6365
GraphStore inputGraphStore,
6466
NodesSampler nodesSampler,
65-
ProgressTracker progressTracker
67+
ProgressTracker progressTracker,
68+
TerminationFlag terminationFlag
6669
) {
6770
this.config = config;
6871
this.concurrency = config.concurrency();
6972
this.inputGraphStore = inputGraphStore;
7073
this.nodesSampler = nodesSampler;
7174
this.progressTracker = progressTracker;
75+
this.terminationFlag = terminationFlag;
7276
}
7377

7478
public GraphStore compute() {
@@ -79,6 +83,8 @@ public GraphStore compute() {
7983
config.internalRelationshipTypes(inputGraphStore),
8084
config.relationshipWeightProperty()
8185
);
86+
nodesSampler.setTerminationFlag(terminationFlag);
87+
8288
var sampledNodesBitSet = nodesSampler.compute(inputGraph, progressTracker);
8389

8490
progressTracker.beginSubTask("Construct graph");

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/NodesSampler.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,25 @@
2424
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
2525
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2626
import org.neo4j.gds.core.utils.progress.tasks.Task;
27+
import org.neo4j.gds.termination.TerminationFlag;
2728

28-
public interface NodesSampler {
29-
HugeAtomicBitSet compute(Graph inputGraph, ProgressTracker progressTracker);
29+
public abstract class NodesSampler {
30+
protected abstract HugeAtomicBitSet compute(
31+
Graph inputGraph,
32+
ProgressTracker progressTracker
33+
);
3034

31-
Task progressTask(GraphStore graphStore);
35+
protected abstract Task progressTask(GraphStore graphStore);
3236

33-
String progressTaskName();
37+
protected abstract String progressTaskName();
38+
39+
protected volatile TerminationFlag terminationFlag = TerminationFlag.RUNNING_TRUE;
40+
41+
public void setTerminationFlag(TerminationFlag terminationFlag) {
42+
this.terminationFlag = terminationFlag;
43+
}
44+
45+
public TerminationFlag getTerminationFlag() {
46+
return terminationFlag;
47+
}
3448
}

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/RandomWalkBasedNodesSampler.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
*/
2020
package org.neo4j.gds.graphsampling;
2121

22-
public interface RandomWalkBasedNodesSampler extends NodesSampler {
22+
public abstract class RandomWalkBasedNodesSampler extends NodesSampler {
2323

24-
long startNodesCount();
24+
public abstract long startNodesCount();
2525

2626
}

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/samplers/NodeLabelHistogram.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
3131
import org.neo4j.gds.core.utils.partition.PartitionUtils;
3232
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
33+
import org.neo4j.gds.termination.TerminationFlag;
3334

3435
import java.util.Comparator;
3536
import java.util.Optional;
@@ -45,7 +46,9 @@ interface Result {
4546
LongLongHashMap histogram();
4647
}
4748

48-
public static Result compute(Graph inputGraph, Concurrency concurrency, ProgressTracker progressTracker) {
49+
public static Result compute(Graph inputGraph, Concurrency concurrency, ProgressTracker progressTracker,
50+
TerminationFlag terminationFlag
51+
) {
4952
progressTracker.beginSubTask("Count node labels");
5053
progressTracker.setSteps(inputGraph.nodeCount());
5154

@@ -61,6 +64,7 @@ public static Result compute(Graph inputGraph, Concurrency concurrency, Progress
6164
concurrency,
6265
inputGraph.nodeCount(),
6366
partition -> (Runnable) () -> {
67+
terminationFlag.assertRunning();
6468
var labelCount = new LongLongHashMap();
6569
partition.consume(nodeId -> {
6670
labelCount.addTo(

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/samplers/SeenNodes.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.neo4j.gds.core.concurrency.Concurrency;
2727
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
2828
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
29+
import org.neo4j.gds.termination.TerminationFlag;
2930

3031
import java.util.Arrays;
3132

@@ -40,14 +41,16 @@ public interface SeenNodes {
4041
long totalExpectedNodes();
4142

4243
static SeenNodes create(
43-
Graph inputGraph, ProgressTracker progressTracker, boolean nodeLabelStratification,
44+
Graph inputGraph, ProgressTracker progressTracker, TerminationFlag terminationFlag,
45+
boolean nodeLabelStratification,
4446
Concurrency concurrency, double samplingRatio
4547
) {
4648
if (nodeLabelStratification) {
4749
var nodeLabelHistogram = NodeLabelHistogram.compute(
4850
inputGraph,
4951
concurrency,
50-
progressTracker
52+
progressTracker,
53+
terminationFlag
5154
);
5255

5356
return new SeenNodes.SeenNodesByLabelSet(inputGraph, nodeLabelHistogram, samplingRatio);

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/samplers/rw/Walker.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2828
import org.neo4j.gds.graphsampling.samplers.SeenNodes;
2929
import org.neo4j.gds.graphsampling.samplers.rw.rwr.RandomWalkWithRestarts;
30+
import org.neo4j.gds.termination.TerminationFlag;
3031

3132
import java.util.Optional;
3233
import java.util.SplittableRandom;
@@ -41,6 +42,7 @@ public class Walker implements Runnable {
4142
protected final Graph inputGraph;
4243
private final double restartProbability;
4344
protected final ProgressTracker progressTracker;
45+
private final TerminationFlag terminationFlag;
4446

4547
private final LongSet startNodesUsed;
4648

@@ -55,6 +57,7 @@ public Walker(
5557
Graph inputGraph,
5658
double restartProbability,
5759
ProgressTracker progressTracker,
60+
TerminationFlag terminationFlag,
5861
NextNodeStrategy nextNodeStrategy
5962
) {
6063
this.seenNodes = seenNodes;
@@ -65,6 +68,7 @@ public Walker(
6568
this.inputGraph = inputGraph;
6669
this.restartProbability = restartProbability;
6770
this.progressTracker = progressTracker;
71+
this.terminationFlag = terminationFlag;
6872
this.startNodesUsed = new LongHashSet();
6973
this.nextNodeStrategy = nextNodeStrategy;
7074
}
@@ -78,7 +82,7 @@ public void run() {
7882
int nodesConsidered = 1;
7983
int walksLeft = (int) Math.round(walkQualities.nodeQuality(currentStartNodePosition) * RandomWalkWithRestarts.MAX_WALKS_PER_START);
8084

81-
while (!seenNodes.hasSeenEnough()) {
85+
while (!seenNodes.hasSeenEnough() && terminationFlag.running()) {
8286
if (seenNodes.addNode(currentNode)) {
8387
addedNodes++;
8488
}
@@ -118,6 +122,7 @@ public void run() {
118122
nodesConsidered++;
119123
}
120124
}
125+
terminationFlag.assertRunning();
121126
}
122127

123128
private double computeDegree(long currentNode) {

0 commit comments

Comments
 (0)