Skip to content

Commit f4f3d01

Browse files
authored
Merge pull request #9684 from orazve/termination-check
Using termination flag for GraphSage
2 parents (5025553 + 77d8bef), commit f4f3d01

File tree

17 files changed

+469
-46
lines changed

17 files changed

+469
-46
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/BatchSampler.java

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.core.utils.partition.PartitionUtils;
2929
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3030
import org.neo4j.gds.ml.core.samplers.WeightedUniformSampler;
31+
import org.neo4j.gds.termination.TerminationFlag;
3132

3233
import java.util.Arrays;
3334
import java.util.List;
@@ -39,17 +40,21 @@ final class BatchSampler {
3940
public static final double DEGREE_SMOOTHING_FACTOR = 0.75;
4041
private final Graph graph;
4142
private final ProgressTracker progressTracker;
43+
private final TerminationFlag terminationFlag;
4244

43-
BatchSampler(Graph graph, ProgressTracker progressTracker) {
45+
BatchSampler(Graph graph, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
4446
this.graph = graph;
4547
this.progressTracker = progressTracker;
48+
this.terminationFlag = terminationFlag;
4649
}
4750

4851
List<long[]> extendedBatches(int batchSize, int searchDepth, long randomSeed) {
4952
return PartitionUtils.rangePartitionWithBatchSize(
5053
graph.nodeCount(),
5154
batchSize,
5255
batch -> {
56+
terminationFlag.assertRunning();
57+
5358
var localSeed = Math.toIntExact(Math.floorDiv(batch.startNode(), graph.nodeCount())) + randomSeed;
5459
long[] extendedBatch = sampleNeighborAndNegativeNodePerBatchNode(batch, searchDepth, localSeed);
5560

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageEmbeddingsGenerator.java

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,16 +20,17 @@
2020
package org.neo4j.gds.embeddings.graphsage;
2121

2222
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.collections.ha.HugeObjectArray;
2324
import org.neo4j.gds.core.concurrency.Concurrency;
2425
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
25-
import org.neo4j.gds.collections.ha.HugeObjectArray;
2626
import org.neo4j.gds.core.utils.partition.Partition;
2727
import org.neo4j.gds.core.utils.partition.PartitionUtils;
2828
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2929
import org.neo4j.gds.ml.core.ComputationContext;
3030
import org.neo4j.gds.ml.core.Variable;
3131
import org.neo4j.gds.ml.core.subgraph.SubGraph;
3232
import org.neo4j.gds.ml.core.tensor.Matrix;
33+
import org.neo4j.gds.termination.TerminationFlag;
3334

3435
import java.util.List;
3536
import java.util.Optional;
@@ -44,6 +45,7 @@ public class GraphSageEmbeddingsGenerator {
4445
private final long randomSeed;
4546
private final ExecutorService executor;
4647
private final ProgressTracker progressTracker;
48+
private final TerminationFlag terminationFlag;
4749

4850
public GraphSageEmbeddingsGenerator(
4951
Layer[] layers,
@@ -52,7 +54,8 @@ public GraphSageEmbeddingsGenerator(
5254
FeatureFunction featureFunction,
5355
Optional<Long> randomSeed,
5456
ExecutorService executor,
55-
ProgressTracker progressTracker
57+
ProgressTracker progressTracker,
58+
TerminationFlag terminationFlag
5659
) {
5760
this.layers = layers;
5861
this.batchSize = batchSize;
@@ -61,6 +64,7 @@ public GraphSageEmbeddingsGenerator(
6164
this.randomSeed = randomSeed.orElseGet(() -> ThreadLocalRandom.current().nextLong());
6265
this.executor = executor;
6366
this.progressTracker = progressTracker;
67+
this.terminationFlag = terminationFlag;
6468
}
6569

6670
public HugeObjectArray<double[]> makeEmbeddings(
@@ -77,7 +81,7 @@ public HugeObjectArray<double[]> makeEmbeddings(
7781
var tasks = PartitionUtils.rangePartitionWithBatchSize(
7882
graph.nodeCount(),
7983
batchSize,
80-
partition -> createEmbeddings(graph.concurrentCopy(), partition, features, result)
84+
partition -> createEmbeddings(graph.concurrentCopy(), partition, features, result, terminationFlag)
8185
);
8286

8387
RunWithConcurrency.builder()
@@ -95,9 +99,11 @@ private Runnable createEmbeddings(
9599
Graph graph,
96100
Partition partition,
97101
HugeObjectArray<double[]> features,
98-
HugeObjectArray<double[]> result
102+
HugeObjectArray<double[]> result,
103+
TerminationFlag terminationFlag
99104
) {
100105
return () -> {
106+
terminationFlag.assertRunning();
101107
List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(
102108
graph,
103109
partition.stream().toArray(),

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
import org.neo4j.gds.ml.core.tensor.Matrix;
4141
import org.neo4j.gds.ml.core.tensor.Scalar;
4242
import org.neo4j.gds.ml.core.tensor.Tensor;
43+
import org.neo4j.gds.termination.TerminationFlag;
4344

4445
import java.util.ArrayList;
4546
import java.util.Arrays;
@@ -66,15 +67,17 @@ public class GraphSageModelTrainer {
6667
private final ExecutorService executor;
6768
private final ProgressTracker progressTracker;
6869
private final Layer[] layers;
70+
private final TerminationFlag terminationFlag;
6971

70-
public GraphSageModelTrainer(GraphSageTrainParameters parameters, int featureDimension, ExecutorService executor, ProgressTracker progressTracker) {
71-
this(parameters, executor, progressTracker, new SingleLabelFeatureFunction(), Collections.emptyList(), featureDimension);
72+
public GraphSageModelTrainer(GraphSageTrainParameters parameters, int featureDimension, ExecutorService executor, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
73+
this(parameters, executor, progressTracker, terminationFlag, new SingleLabelFeatureFunction(), Collections.emptyList(), featureDimension);
7274
}
7375

7476
public GraphSageModelTrainer(
7577
GraphSageTrainParameters parameters,
7678
ExecutorService executor,
7779
ProgressTracker progressTracker,
80+
TerminationFlag terminationFlag,
7881
FeatureFunction featureFunction,
7982
Collection<Weights<Matrix>> labelProjectionWeights,
8083
int featureDimension
@@ -84,6 +87,7 @@ public GraphSageModelTrainer(
8487
this.labelProjectionWeights = labelProjectionWeights;
8588
this.executor = executor;
8689
this.progressTracker = progressTracker;
90+
this.terminationFlag = terminationFlag;
8791
this.randomSeed = parameters.randomSeed().orElseGet(() -> ThreadLocalRandom.current().nextLong());
8892
this.layers = parameters.layerConfigs(featureDimension)
8993
.stream()
@@ -114,7 +118,7 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
114118

115119
progressTracker.beginSubTask("Prepare batches");
116120

117-
var batchSampler = new BatchSampler(graph, progressTracker);
121+
var batchSampler = new BatchSampler(graph, progressTracker, terminationFlag);
118122

119123
List<long[]> extendedBatches = batchSampler
120124
.extendedBatches(parameters.batchSize(), parameters.searchDepth(), randomSeed);
@@ -135,6 +139,7 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
135139

136140
for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
137141
progressTracker.beginSubTask("Epoch");
142+
terminationFlag.assertRunning();
138143
// also tried using random.nextLong() but this somehow had a worse quality
139144
long epochLocalSeed = epoch + randomSeed;
140145

@@ -254,6 +259,7 @@ private EpochResult trainEpoch(
254259
int maxIterations = parameters.maxIterations();
255260
for (; iteration <= maxIterations; iteration++) {
256261
progressTracker.beginSubTask("Iteration");
262+
terminationFlag.assertRunning();
257263

258264
var sampledBatchTasks = sampledBatchTaskSupplier.get();
259265

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSage.java

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
3131
import org.neo4j.gds.embeddings.graphsage.Layer;
3232
import org.neo4j.gds.embeddings.graphsage.ModelData;
33+
import org.neo4j.gds.termination.TerminationFlag;
3334

3435
import java.util.concurrent.ExecutorService;
3536

@@ -52,14 +53,16 @@ public GraphSage(
5253
Concurrency concurrency,
5354
int batchSize,
5455
ExecutorService executor,
55-
ProgressTracker progressTracker
56+
ProgressTracker progressTracker,
57+
TerminationFlag terminationFlag
5658
) {
5759
super(progressTracker);
5860
this.graph = graph;
5961
this.concurrency = concurrency;
6062
this.batchSize = batchSize;
6163
this.model = model;
6264
this.executor = executor;
65+
this.terminationFlag = terminationFlag;
6366
}
6467

6568
@Override
@@ -73,7 +76,8 @@ public GraphSageResult compute() {
7376
model.data().featureFunction(),
7477
model.trainConfig().randomSeed(),
7578
executor,
76-
progressTracker
79+
progressTracker,
80+
terminationFlag
7781
);
7882

7983
GraphSageTrainConfig trainConfig = model.trainConfig();

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageAlgorithmFactory.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
3333
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer.GraphSageTrainMetrics;
3434
import org.neo4j.gds.embeddings.graphsage.ModelData;
35+
import org.neo4j.gds.termination.TerminationFlag;
3536

3637
import static org.neo4j.gds.embeddings.graphsage.algo.GraphSageModelResolver.resolveModel;
3738
import static org.neo4j.gds.ml.core.EmbeddingUtils.validateRelationshipWeightPropertyValue;
@@ -63,7 +64,8 @@ public GraphSage build(
6364
parameters.concurrency(),
6465
parameters.batchSize(),
6566
executorService,
66-
progressTracker
67+
progressTracker,
68+
TerminationFlag.RUNNING_TRUE
6769
);
6870
}
6971

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrain.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,11 +24,13 @@
2424
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2525
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
2626
import org.neo4j.gds.embeddings.graphsage.ModelData;
27+
import org.neo4j.gds.termination.TerminationFlag;
2728

2829
public abstract class GraphSageTrain extends Algorithm<Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics>> {
2930

30-
protected GraphSageTrain(ProgressTracker progressTracker) {
31+
protected GraphSageTrain(ProgressTracker progressTracker, TerminationFlag terminationFlag) {
3132
super(progressTracker);
33+
setTerminationFlag(terminationFlag);
3234
}
3335

3436
}

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainAlgorithmFactory.java

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,12 +23,13 @@
2323
import org.neo4j.gds.api.Graph;
2424
import org.neo4j.gds.compat.GdsVersionInfoProvider;
2525
import org.neo4j.gds.core.concurrency.DefaultPool;
26-
import org.neo4j.gds.embeddings.graphsage.TrainConfigTransformer;
27-
import org.neo4j.gds.mem.MemoryEstimation;
2826
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2927
import org.neo4j.gds.core.utils.progress.tasks.Task;
3028
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
3129
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
30+
import org.neo4j.gds.embeddings.graphsage.TrainConfigTransformer;
31+
import org.neo4j.gds.mem.MemoryEstimation;
32+
import org.neo4j.gds.termination.TerminationFlag;
3233

3334
import static org.neo4j.gds.ml.core.EmbeddingUtils.validateRelationshipWeightPropertyValue;
3435

@@ -59,8 +60,8 @@ public GraphSageTrain build(
5960

6061
var parameters = TrainConfigTransformer.toParameters(configuration);
6162
return configuration.isMultiLabel()
62-
? new MultiLabelGraphSageTrain(graph, parameters, configuration.projectedFeatureDimension().get(), executorService, progressTracker, gdsVersion, configuration)
63-
: new SingleLabelGraphSageTrain(graph, parameters, executorService, progressTracker, gdsVersion, configuration);
63+
? new MultiLabelGraphSageTrain(graph, parameters, configuration.projectedFeatureDimension().get(), executorService, progressTracker, TerminationFlag.RUNNING_TRUE, gdsVersion, configuration)
64+
: new SingleLabelGraphSageTrain(graph, parameters, executorService, progressTracker, TerminationFlag.RUNNING_TRUE, gdsVersion, configuration);
6465
}
6566

6667
public MemoryEstimation memoryEstimation(GraphSageTrainMemoryEstimateParameters parameters) {

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/MultiLabelGraphSageTrain.java

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import org.neo4j.gds.embeddings.graphsage.MultiLabelFeatureFunction;
3030
import org.neo4j.gds.ml.core.functions.Weights;
3131
import org.neo4j.gds.ml.core.tensor.Matrix;
32+
import org.neo4j.gds.termination.TerminationFlag;
3233

3334
import java.util.Map;
3435
import java.util.Optional;
@@ -57,10 +58,11 @@ public MultiLabelGraphSageTrain(
5758
int projectedFeatureDimension,
5859
ExecutorService executor,
5960
ProgressTracker progressTracker,
61+
TerminationFlag terminationFlag,
6062
String gdsVersion,
6163
GraphSageTrainConfig config // TODO: Last trace of UI config in here--Once we attach Parameters to Models we can lose this too
6264
) {
63-
super(progressTracker);
65+
super(progressTracker, terminationFlag);
6466
this.graph = graph;
6567
this.featureDimension = projectedFeatureDimension;
6668
this.parameters = parameters;
@@ -72,6 +74,7 @@ public MultiLabelGraphSageTrain(
7274
@Override
7375
public Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> compute() {
7476
progressTracker.beginSubTask("GraphSageTrain");
77+
terminationFlag.assertRunning();
7578
var multiLabelFeatureExtractors = GraphSageHelper.multiLabelFeatureExtractors(
7679
graph,
7780
parameters.featureProperties()
@@ -82,6 +85,7 @@ public Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTra
8285
parameters,
8386
executor,
8487
progressTracker,
88+
terminationFlag,
8589
multiLabelFeatureFunction,
8690
multiLabelFeatureFunction.weightsByLabel().values(),
8791
featureDimension

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/SingleLabelGraphSageTrain.java

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
import org.neo4j.gds.embeddings.graphsage.ModelData;
2727
import org.neo4j.gds.embeddings.graphsage.SingleLabelFeatureFunction;
2828
import org.neo4j.gds.ml.core.features.FeatureExtraction;
29+
import org.neo4j.gds.termination.TerminationFlag;
2930

3031
import java.util.concurrent.ExecutorService;
3132

@@ -45,10 +46,11 @@ public SingleLabelGraphSageTrain(
4546
GraphSageTrainParameters parameters,
4647
ExecutorService executor,
4748
ProgressTracker progressTracker,
49+
TerminationFlag terminationFlag,
4850
String gdsVersion,
4951
GraphSageTrainConfig config // TODO: Last trace of UI config in here--Once we attach Parameters to Models we can lose this too
5052
) {
51-
super(progressTracker);
53+
super(progressTracker, terminationFlag);
5254
this.graph = graph;
5355
this.parameters = parameters;
5456
this.executor = executor;
@@ -65,7 +67,8 @@ public Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTra
6567
parameters,
6668
featureDimension,
6769
executor,
68-
progressTracker
70+
progressTracker,
71+
terminationFlag
6972
);
7073

7174
GraphSageModelTrainer.ModelTrainResult trainResult = graphSageModel.train(

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/BatchSamplerTest.java

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.extension.GdlGraph;
2929
import org.neo4j.gds.extension.Inject;
3030
import org.neo4j.gds.gdl.GdlFactory;
31+
import org.neo4j.gds.termination.TerminationFlag;
3132

3233
import java.util.function.Function;
3334
import java.util.stream.Collectors;
@@ -49,7 +50,7 @@ void sampleDenseGraph() {
4950
Partition allNodes = Partition.of(0, 2);
5051
int searchDepth = 3;
5152

52-
assertThat(new BatchSampler(clique, ProgressTracker.NULL_TRACKER).sampleNeighborAndNegativeNodePerBatchNode(allNodes, searchDepth, 42))
53+
assertThat(new BatchSampler(clique, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE).sampleNeighborAndNegativeNodePerBatchNode(allNodes, searchDepth, 42))
5354
.containsExactly(
5455
0L, 1L,
5556
// positive samples
@@ -75,8 +76,8 @@ void seededNegativeBatch() {
7576

7677
for (int i = 0; i < partitions.size(); i++) {
7778
var localSeed = i + seed;
78-
var negativeBatch = new BatchSampler(graph, ProgressTracker.NULL_TRACKER).negativeBatch(Math.toIntExact(partitions.get(i).nodeCount()), neighborsSet, localSeed);
79-
var otherNegativeBatch = new BatchSampler(graph, ProgressTracker.NULL_TRACKER).negativeBatch(Math.toIntExact(partitions.get(i).nodeCount()), neighborsSet, localSeed);
79+
var negativeBatch = new BatchSampler(graph, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE).negativeBatch(Math.toIntExact(partitions.get(i).nodeCount()), neighborsSet, localSeed);
80+
var otherNegativeBatch = new BatchSampler(graph, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE).negativeBatch(Math.toIntExact(partitions.get(i).nodeCount()), neighborsSet, localSeed);
8081

8182
assertThat(negativeBatch).containsExactlyElementsOf(otherNegativeBatch.boxed().collect(Collectors.toList()));
8283
}
@@ -96,8 +97,8 @@ void seededNeighborBatch() {
9697

9798
for (int i = 0; i < partitions.size(); i++) {
9899
var localSeed = i + seed;
99-
var neighborBatch = new BatchSampler(graph, ProgressTracker.NULL_TRACKER).neighborBatch(partitions.get(i), localSeed, searchDepth);
100-
var otherNeighborBatch = new BatchSampler(graph, ProgressTracker.NULL_TRACKER).neighborBatch(partitions.get(i), localSeed, searchDepth);
100+
var neighborBatch = new BatchSampler(graph, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE).neighborBatch(partitions.get(i), localSeed, searchDepth);
101+
var otherNeighborBatch = new BatchSampler(graph, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE).neighborBatch(partitions.get(i), localSeed, searchDepth);
101102
assertThat(neighborBatch).containsExactly(otherNeighborBatch);
102103
}
103104
}

0 commit comments

Comments
 (0)