Skip to content

Commit 811279c

Browse files
breakanalysisadamnsch
authored andcommitted
Tune LP train task volumes and include splitting and node property steps
Co-Authored-By: Adam Schill Collberg<adam.schill.collberg@protonmail.com>
1 parent becb552 commit 811279c

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,17 @@ public static List<Task> progressTasks(
102102
) {
103103
var sizes = splitConfig.expectedSetSizes(relationshipCount);
104104
return List.of(
105-
Tasks.leaf("Extract train features", sizes.trainSize()),
105+
Tasks.leaf("Extract train features", sizes.trainSize() * 3),
106106
Tasks.iterativeFixed(
107107
"Select best model",
108-
() -> List.of(Tasks.leaf("Trial", splitConfig.validationFolds() * sizes.trainSize())),
108+
() -> List.of(Tasks.leaf("Trial", splitConfig.validationFolds() * sizes.trainSize() * 5)),
109109
numberOfModelSelectionTrials
110110
),
111-
ClassifierTrainer.progressTask("Train best model", sizes.trainSize()),
111+
ClassifierTrainer.progressTask("Train best model", sizes.trainSize() * 5),
112112
Tasks.leaf("Compute train metrics", sizes.trainSize()),
113113
Tasks.task(
114114
"Evaluate on test data",
115-
Tasks.leaf("Extract test features", sizes.testSize()),
115+
Tasks.leaf("Extract test features", sizes.testSize() * 3),
116116
Tasks.leaf("Compute test metrics", sizes.testSize())
117117
)
118118
);

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineExecutor.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,15 @@ public LinkPredictionTrainPipelineExecutor(
8282
}
8383

8484
public static Task progressTask(String taskName, LinkPredictionTrainingPipeline pipeline, long relationshipCount) {
85+
var sizes = pipeline.splitConfig().expectedSetSizes(relationshipCount);
8586
return Tasks.task(taskName, new ArrayList<>() {{
86-
add(Tasks.leaf("Split relationships"));
87+
add(Tasks.leaf(
88+
"Split relationships",
89+
sizes.trainSize() + sizes.featureInputSize() + sizes.testSize() + sizes.testComplementSize()
90+
));
8791
add(Tasks.iterativeFixed(
8892
"Execute node property steps",
89-
() -> List.of(Tasks.leaf("Step")),
93+
() -> List.of(Tasks.leaf("Step", 10 * sizes.featureInputSize())),
9094
pipeline.nodePropertySteps().size()
9195
));
9296
addAll(LinkPredictionTrain.progressTasks(

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/RelationshipSplitter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public void splitRelationships(
7171
Optional<Long> randomSeed,
7272
Optional<String> relationshipWeightProperty
7373
) {
74-
progressTracker.beginSubTask();
74+
progressTracker.beginSubTask("Split relationships");
7575

7676
splitConfig.validateAgainstGraphStore(graphStore);
7777

@@ -90,7 +90,7 @@ public void splitRelationships(
9090

9191
graphStore.deleteRelationships(RelationshipType.of(testComplementRelationshipType));
9292

93-
progressTracker.endSubTask();
93+
progressTracker.endSubTask("Split relationships");
9494
}
9595

9696
private void validateTestSplit(GraphStore graphStore) {

0 commit comments

Comments
 (0)