Skip to content

Commit b933ef5

Browse files
authored
Merge pull request #9715 from lassewesth/lpppp1
migrate link prediction add steps
2 parents 7baf1dc + e7a31c6 commit b933ef5

File tree

14 files changed

+195
-79
lines changed

14 files changed

+195
-79
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipeline.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import org.neo4j.gds.config.RelationshipWeightConfig;
2424
import org.neo4j.gds.config.ToMapConvertible;
2525
import org.neo4j.gds.core.model.Model;
26-
import org.neo4j.gds.executor.ExecutionContext;
26+
import org.neo4j.gds.core.model.ModelCatalog;
2727
import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep;
2828
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
2929
import org.neo4j.gds.settings.Neo4jSettings;
@@ -86,12 +86,12 @@ public void specificValidateBeforeExecution(GraphStore graphStore) {
8686
}
8787
}
8888

89-
public Map<String, List<String>> tasksByRelationshipProperty(ExecutionContext executionContext) {
89+
public Map<String, List<String>> tasksByRelationshipProperty(ModelCatalog modelCatalog, String username) {
9090
Map<String, List<String>> tasksByRelationshipProperty = new HashMap<>();
9191

9292
for (ExecutableNodePropertyStep existingStep : nodePropertySteps()) {
9393
Map<String, Object> config = existingStep.config();
94-
Optional<String> maybeProperty = extractRelationshipProperty(executionContext, config);
94+
Optional<String> maybeProperty = extractRelationshipProperty(config, modelCatalog, username);
9595

9696
maybeProperty.ifPresent(property -> {
9797
var tasks = tasksByRelationshipProperty.computeIfAbsent(property, key -> new ArrayList<>());
@@ -102,16 +102,17 @@ public Map<String, List<String>> tasksByRelationshipProperty(ExecutionContext ex
102102
return tasksByRelationshipProperty;
103103
}
104104

105-
private static Optional<String> extractRelationshipProperty(
106-
ExecutionContext executionContext,
107-
Map<String, Object> config
105+
private Optional<String> extractRelationshipProperty(
106+
Map<String, Object> config,
107+
ModelCatalog modelCatalog,
108+
String username
108109
) {
109110
if (config.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) {
110111
var existingProperty = (String) config.get(RELATIONSHIP_WEIGHT_PROPERTY);
111112
return Optional.of(existingProperty);
112113
} else if (config.containsKey(MODEL_NAME_KEY)) {
113-
return Optional.ofNullable(executionContext.modelCatalog().getUntyped(
114-
executionContext.username(),
114+
return Optional.ofNullable(modelCatalog.getUntyped(
115+
username,
115116
((String) config.get(MODEL_NAME_KEY))
116117
))
117118
.map(Model::trainConfig)
@@ -122,8 +123,10 @@ private static Optional<String> extractRelationshipProperty(
122123
return Optional.empty();
123124
}
124125

125-
public Optional<String> relationshipWeightProperty(ExecutionContext executionContext) {
126-
var relationshipWeightPropertySet = tasksByRelationshipProperty(executionContext).entrySet();
126+
public Optional<String> relationshipWeightProperty(ModelCatalog modelCatalog, String username) {
127+
var relationshipWeightPropertySet = tasksByRelationshipProperty(
128+
modelCatalog, username
129+
).entrySet();
127130
return relationshipWeightPropertySet.isEmpty()
128131
? Optional.empty()
129132
: Optional.of(relationshipWeightPropertySet.iterator().next().getKey());

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainPipelineExecutor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ public static MemoryEstimation estimate(
112112
var splitEstimations = splitEstimation(
113113
pipeline.splitConfig(),
114114
configuration.targetRelationshipType(),
115-
pipeline.relationshipWeightProperty(executionContext)
115+
pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())
116116
);
117117

118118
MemoryEstimation maxOverNodePropertySteps = NodePropertyStepExecutor.estimateNodePropertySteps(
@@ -156,7 +156,7 @@ public Map<DatasetSplits, PipelineGraphFilter> generateDatasetSplitGraphFilters(
156156

157157
@Override
158158
public void splitDatasets() {
159-
this.linkPredictionRelationshipSampler.splitAndSampleRelationships(pipeline.relationshipWeightProperty(executionContext));
159+
this.linkPredictionRelationshipSampler.splitAndSampleRelationships(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username()));
160160
}
161161

162162
@Override

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,13 @@ public boolean containsDependency(Class<?> type) {
212212

213213
var pipeline = new LinkPredictionTrainingPipeline();
214214

215-
assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
215+
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty();
216216

217217
var step = new TestNodePropertyStep(Map.of("relationshipWeightProperty", "myWeight"));
218218

219219
pipeline.addNodePropertyStep(step);
220220

221-
assertThat(pipeline.relationshipWeightProperty(executionContext)).isPresent().get().isEqualTo("myWeight");
221+
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isPresent().get().isEqualTo("myWeight");
222222
}
223223

224224
@Test
@@ -266,13 +266,13 @@ public boolean containsDependency(Class<?> type) {
266266

267267
var pipeline = new LinkPredictionTrainingPipeline();
268268

269-
assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
269+
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty();
270270

271271
var step = new TestNodePropertyStep(Map.of("modelName", modelName));
272272

273273
pipeline.addNodePropertyStep(step);
274274

275-
assertThat(pipeline.relationshipWeightProperty(executionContext)).isPresent().get().isEqualTo("derivedWeight");
275+
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isPresent().get().isEqualTo("derivedWeight");
276276
}
277277

278278
@Test
@@ -320,13 +320,13 @@ public boolean containsDependency(Class<?> type) {
320320

321321
var pipeline = new LinkPredictionTrainingPipeline();
322322

323-
assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
323+
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty();
324324

325325
var step = new TestNodePropertyStep(Map.of("modelName", modelName));
326326

327327
pipeline.addNodePropertyStep(step);
328328

329-
assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
329+
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty();
330330
}
331331

332332
private static class TestNodePropertyStep implements ExecutableNodePropertyStep {

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddStepProcs.java

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,25 @@
2121

2222
import org.neo4j.gds.BaseProc;
2323
import org.neo4j.gds.core.CypherMapWrapper;
24-
import org.neo4j.gds.ml.pipeline.NodePropertyStepFactory;
2524
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
2625
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStepFactory;
2726
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
2827
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.LinkFeatureStepConfigurationImpl;
28+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
29+
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
30+
import org.neo4j.procedure.Context;
2931
import org.neo4j.procedure.Description;
3032
import org.neo4j.procedure.Name;
3133
import org.neo4j.procedure.Procedure;
3234

3335
import java.util.Map;
34-
import java.util.stream.Collectors;
3536
import java.util.stream.Stream;
3637

37-
import static org.neo4j.gds.config.RelationshipWeightConfig.RELATIONSHIP_WEIGHT_PROPERTY;
38-
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
3938
import static org.neo4j.procedure.Mode.READ;
4039

4140
public class LinkPredictionPipelineAddStepProcs extends BaseProc {
41+
@Context
42+
public GraphDataScienceProcedures facade;
4243

4344
@Procedure(name = "gds.beta.pipeline.linkPrediction.addNodeProperty", mode = READ)
4445
@Description("Add a node property step to an existing link prediction pipeline.")
@@ -47,12 +48,7 @@ public Stream<PipelineInfoResult> addNodeProperty(
4748
@Name("procedureName") String taskName,
4849
@Name("procedureConfiguration") Map<String, Object> procedureConfig
4950
) {
50-
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, LinkPredictionTrainingPipeline.class);
51-
validateRelationshipProperty(pipeline, procedureConfig);
52-
53-
pipeline.addNodePropertyStep(NodePropertyStepFactory.createNodePropertyStep(taskName, procedureConfig));
54-
55-
return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
51+
return facade.pipelines().linkPrediction().addNodeProperty(pipelineName, taskName, procedureConfig);
5652
}
5753

5854
@Procedure(name = "gds.beta.pipeline.linkPrediction.addFeature", mode = READ)
@@ -68,33 +64,6 @@ public Stream<PipelineInfoResult> addFeature(
6864

6965
pipeline.addFeatureStep(LinkFeatureStepFactory.create(featureType, parsedConfig));
7066

71-
return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
72-
}
73-
74-
// check if adding would result in more than one relationshipWeightProperty
75-
private void validateRelationshipProperty(
76-
LinkPredictionTrainingPipeline pipeline,
77-
Map<String, Object> procedureConfig
78-
) {
79-
if (!procedureConfig.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) return;
80-
var maybeRelationshipProperty = pipeline.relationshipWeightProperty(executionContext());
81-
if (maybeRelationshipProperty.isEmpty()) return;
82-
var relationshipProperty = maybeRelationshipProperty.get();
83-
var property = (String) procedureConfig.get(RELATIONSHIP_WEIGHT_PROPERTY);
84-
if (relationshipProperty.equals(property)) return;
85-
86-
String tasks = pipeline.tasksByRelationshipProperty(executionContext())
87-
.get(relationshipProperty)
88-
.stream()
89-
.map(s -> "`" + s + "`")
90-
.collect(Collectors.joining(", "));
91-
throw new IllegalArgumentException(formatWithLocale(
92-
"Node property steps added to a pipeline may not have different non-null values for `%s`. " +
93-
"Pipeline already contains tasks %s which use the value `%s`.",
94-
RELATIONSHIP_WEIGHT_PROPERTY,
95-
tasks,
96-
relationshipProperty
97-
));
67+
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
9868
}
99-
10069
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
2929
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
3030
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
31+
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
3132
import org.neo4j.procedure.Description;
3233
import org.neo4j.procedure.Internal;
3334
import org.neo4j.procedure.Name;
@@ -56,7 +57,7 @@ public Stream<PipelineInfoResult> addLogisticRegression(
5657
tunableTrainerConfig
5758
);
5859

59-
return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
60+
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
6061
}
6162

6263
@Procedure(name = "gds.beta.pipeline.linkPrediction.addRandomForest", mode = READ)
@@ -75,7 +76,7 @@ public Stream<PipelineInfoResult> addRandomForest(
7576
tunableTrainerConfig
7677
);
7778

78-
return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
79+
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
7980
}
8081

8182
@Procedure(name = "gds.alpha.pipeline.linkPrediction.addRandomForest", mode = READ, deprecatedBy = "gds.beta.pipeline.linkPrediction.addRandomForest")
@@ -109,6 +110,6 @@ public Stream<PipelineInfoResult> addMLP(
109110

110111
pipeline.addTrainerConfig(TunableTrainerConfig.of(mlpClassifierConfig, TrainingMethod.MLPClassification));
111112

112-
return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
113+
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
113114
}
114115
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureAutoTuningProc.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
2424
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
2525
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
26+
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
2627
import org.neo4j.procedure.Description;
2728
import org.neo4j.procedure.Name;
2829
import org.neo4j.procedure.Procedure;
@@ -42,7 +43,7 @@ public Stream<PipelineInfoResult> configureAutoTuning(@Name("pipelineName") Stri
4243
username(),
4344
pipelineName,
4445
configMap,
45-
pipeline -> new PipelineInfoResult(pipelineName, (LinkPredictionTrainingPipeline) pipeline)
46+
pipeline -> PipelineInfoResult.create(pipelineName, (LinkPredictionTrainingPipeline) pipeline)
4647
);
4748
}
4849
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureSplitProc.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
2525
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
2626
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
27+
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
2728
import org.neo4j.procedure.Description;
2829
import org.neo4j.procedure.Name;
2930
import org.neo4j.procedure.Procedure;
@@ -47,6 +48,6 @@ public Stream<PipelineInfoResult> configureSplit(@Name("pipelineName") String pi
4748

4849
pipeline.setSplitConfig(config);
4950

50-
return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
51+
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
5152
}
5253
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineCreateProc.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.neo4j.gds.core.StringIdentifierValidations;
2424
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
2525
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
26+
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
2627
import org.neo4j.procedure.Description;
2728
import org.neo4j.procedure.Name;
2829
import org.neo4j.procedure.Procedure;
@@ -42,7 +43,7 @@ public Stream<PipelineInfoResult> create(@Name("pipelineName") String input) {
4243
LinkPredictionTrainingPipeline pipeline = new LinkPredictionTrainingPipeline();
4344
PipelineCatalog.set(username(), pipelineName, pipeline);
4445

45-
return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
46+
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
4647
}
4748

4849
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.pipelines;
21+
22+
import java.util.Map;
23+
import java.util.stream.Stream;
24+
25+
public class LinkPredictionFacade {
26+
private final PipelineApplications pipelineApplications;
27+
28+
LinkPredictionFacade(PipelineApplications pipelineApplications) {
29+
this.pipelineApplications = pipelineApplications;
30+
}
31+
32+
public Stream<PipelineInfoResult> addNodeProperty(
33+
String pipelineNameAsString,
34+
String taskName,
35+
Map<String, Object> procedureConfig
36+
) {
37+
var pipelineName = PipelineName.parse(pipelineNameAsString);
38+
39+
var pipeline = pipelineApplications.addNodePropertyToLinkPredictionPipeline(
40+
pipelineName,
41+
taskName,
42+
procedureConfig
43+
);
44+
45+
var result = PipelineInfoResult.create(pipelineName.value, pipeline);
46+
47+
return Stream.of(result);
48+
}
49+
}

0 commit comments

Comments
 (0)