Skip to content

Commit c8ffffd

Browse files
committed
Bootstrap PushbackNodeEmbeddingsStreamProcedureFacade
1 parent 82c6122 commit c8ffffd

File tree

14 files changed

+533
-13
lines changed

14 files changed

+533
-13
lines changed

algorithms-compute-business-facade/src/main/java/org/neo4j/gds/embeddings/NodeEmbeddingComputeBusinessFacade.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ public <TR> CompletableFuture<TR> hashGnn(
9797
GraphParameters graphParameters,
9898
// FIXME: `relationshipTypes` is only used to create progress tracker tasks, and in there only the count...
9999
List<String> relationshipTypes,
100-
Optional<String> relationshipProperty,
101100
HashGNNParameters parameters,
102101
JobId jobId,
103102
boolean logProgress,
@@ -107,7 +106,7 @@ public <TR> CompletableFuture<TR> hashGnn(
107106
var graphResources = graphStoreCatalogService.fetchGraphResources(
108107
graphName,
109108
graphParameters,
110-
relationshipProperty,
109+
Optional.empty(),
111110
new FeaturePropertiesMustExistOnAllNodeLabels(parameters.featureProperties()),
112111
Optional.empty(),
113112
user,

applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsEstimationModeBusinessFacade.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ public MemoryEstimateResult hashGnn(HashGNNConfig configuration, Object graphNam
125125
}
126126

127127
public MemoryEstimation node2Vec(Node2VecBaseConfig configuration) {
128-
return new Node2VecMemoryEstimateDefinition(Node2VecConfigTransformer.node2VecParameters(configuration)).memoryEstimation();
128+
return new Node2VecMemoryEstimateDefinition(Node2VecConfigTransformer.toParameters(configuration)).memoryEstimation();
129129
}
130130

131131
public MemoryEstimateResult node2Vec(Node2VecBaseConfig configuration, Object graphNameOrConfiguration) {

applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingBusinessAlgorithms.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public FastRPResult fastRP(Graph graph, FastRPBaseConfig configuration) {
6969
}
7070

7171
Node2VecResult node2Vec(Graph graph, Node2VecBaseConfig configuration) {
72-
var params = Node2VecConfigTransformer.node2VecParameters(configuration);
72+
var params = Node2VecConfigTransformer.toParameters(configuration);
7373
var task = tasks.node2Vec(graph,params);
7474
var progressTracker = progressTrackerCreator.createProgressTracker(task, configuration);
7575

procedures/facade-api/configs/node-embeddings-configs/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecConfigTransformer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
public final class Node2VecConfigTransformer {
2323
private Node2VecConfigTransformer() {}
2424

25-
public static Node2VecParameters node2VecParameters(Node2VecBaseConfig config) {
25+
public static Node2VecParameters toParameters(Node2VecBaseConfig config) {
2626
var walkParameters = config.walkParameters();
2727

2828
var samplingWalkParameters = new SamplingWalkParameters(

procedures/facade-api/node-embeddings-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/embeddings/DefaultNodeEmbeddingsStreamResult.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,14 @@
2222
import java.util.ArrayList;
2323
import java.util.Arrays;
2424
import java.util.List;
25-
import java.util.stream.Collectors;
2625

2726
public record DefaultNodeEmbeddingsStreamResult(long nodeId, List<Double> embedding) {
2827

29-
static DefaultNodeEmbeddingsStreamResult create(long nodeId, double[] embeddings) {
30-
return new DefaultNodeEmbeddingsStreamResult(nodeId, Arrays.stream(embeddings).boxed().collect(
31-
Collectors.toList()));
28+
public static DefaultNodeEmbeddingsStreamResult create(long nodeId, double[] embeddings) {
29+
return new DefaultNodeEmbeddingsStreamResult(nodeId, Arrays.stream(embeddings).boxed().toList());
3230
}
3331

34-
static DefaultNodeEmbeddingsStreamResult create(long nodeId, float[] embeddingAsArray) {
32+
public static DefaultNodeEmbeddingsStreamResult create(long nodeId, float[] embeddingAsArray) {
3533
var embedding = new ArrayList<Double>(embeddingAsArray.length);
3634
for (var f : embeddingAsArray) {
3735
embedding.add((double) f);

procedures/pushback-procedures-facade/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ dependencies {
4242
// this is needed because of `MemoryEstimation`
4343
implementation project(':memory-usage')
4444

45+
implementation project(':ml-core')
46+
4547
implementation project(':neo4j-api')
4648

4749
implementation project(':path-finding-mutate-steps')

procedures/pushback-procedures-facade/src/main/java/org/neo4j/gds/procedures/algorithms/embeddings/PushbackNodeEmbeddingsProcedureFacade.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,22 @@
2020
package org.neo4j.gds.procedures.algorithms.embeddings;
2121

2222
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
23+
import org.neo4j.gds.procedures.algorithms.embeddings.stream.PushbackNodeEmbeddingsStreamProcedureFacade;
2324
import org.neo4j.gds.procedures.algorithms.embeddings.stubs.NodeEmbeddingsStubs;
2425

2526
import java.util.Map;
2627
import java.util.stream.Stream;
2728

2829
public class PushbackNodeEmbeddingsProcedureFacade implements NodeEmbeddingsProcedureFacade {
2930

31+
private final PushbackNodeEmbeddingsStreamProcedureFacade streamProcedureFacade;
32+
33+
public PushbackNodeEmbeddingsProcedureFacade(
34+
PushbackNodeEmbeddingsStreamProcedureFacade streamProcedureFacade
35+
) {
36+
this.streamProcedureFacade = streamProcedureFacade;
37+
}
38+
3039
@Override
3140
public NodeEmbeddingsStubs nodeEmbeddingStubs() {
3241
return null;
@@ -47,7 +56,7 @@ public Stream<MemoryEstimateResult> fastRPStatsEstimate(
4756

4857
@Override
4958
public Stream<DefaultNodeEmbeddingsStreamResult> fastRPStream(String graphName, Map<String, Object> configuration) {
50-
return Stream.empty();
59+
return streamProcedureFacade.fastRP(graphName, configuration);
5160
}
5261

5362
@Override
@@ -150,7 +159,7 @@ public Stream<DefaultNodeEmbeddingsStreamResult> hashGnnStream(
150159
String graphName,
151160
Map<String, Object> configuration
152161
) {
153-
return Stream.empty();
162+
return streamProcedureFacade.hashGnn(graphName, configuration);
154163
}
155164

156165
@Override
@@ -192,7 +201,7 @@ public Stream<DefaultNodeEmbeddingsStreamResult> node2VecStream(
192201
String graphName,
193202
Map<String, Object> configuration
194203
) {
195-
return Stream.empty();
204+
return streamProcedureFacade.node2Vec(graphName, configuration);
196205
}
197206

198207
@Override
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.algorithms.embeddings.stream;
21+
22+
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.api.IdMap;
24+
import org.neo4j.gds.api.properties.nodes.NodePropertyValuesAdapter;
25+
import org.neo4j.gds.embeddings.fastrp.FastRPResult;
26+
import org.neo4j.gds.procedures.algorithms.embeddings.DefaultNodeEmbeddingsStreamResult;
27+
import org.neo4j.gds.result.TimedAlgorithmResult;
28+
import org.neo4j.gds.results.ResultTransformer;
29+
30+
import java.util.stream.LongStream;
31+
import java.util.stream.Stream;
32+
33+
class FastRPStreamResultTransformer implements ResultTransformer<TimedAlgorithmResult<FastRPResult>, Stream<DefaultNodeEmbeddingsStreamResult>> {
34+
private final Graph graph;
35+
36+
FastRPStreamResultTransformer(Graph graph) {
37+
this.graph = graph;
38+
}
39+
40+
@Override
41+
public Stream<DefaultNodeEmbeddingsStreamResult> apply(TimedAlgorithmResult<FastRPResult> algorithmResult) {
42+
var fastRPResult = algorithmResult.result();
43+
var nodePropertyValues = NodePropertyValuesAdapter.adapt(fastRPResult.embeddings());
44+
return LongStream
45+
.range(IdMap.START_NODE_ID, nodePropertyValues.nodeCount())
46+
.filter(nodePropertyValues::hasValue)
47+
.mapToObj(nodeId -> DefaultNodeEmbeddingsStreamResult.create(
48+
graph.toOriginalNodeId(nodeId),
49+
nodePropertyValues.floatArrayValue(nodeId)
50+
));
51+
}
52+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.algorithms.embeddings.stream;
21+
22+
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.api.IdMap;
24+
import org.neo4j.gds.embeddings.hashgnn.HashGNNResult;
25+
import org.neo4j.gds.procedures.algorithms.embeddings.DefaultNodeEmbeddingsStreamResult;
26+
import org.neo4j.gds.result.TimedAlgorithmResult;
27+
import org.neo4j.gds.results.ResultTransformer;
28+
29+
import java.util.stream.LongStream;
30+
import java.util.stream.Stream;
31+
32+
class HashGNNStreamResultTransformer implements ResultTransformer<TimedAlgorithmResult<HashGNNResult>, Stream<DefaultNodeEmbeddingsStreamResult>> {
33+
private final Graph graph;
34+
35+
HashGNNStreamResultTransformer(Graph graph) {
36+
this.graph = graph;
37+
}
38+
39+
@Override
40+
public Stream<DefaultNodeEmbeddingsStreamResult> apply(TimedAlgorithmResult<HashGNNResult> algorithmResult) {
41+
var hashGNNResult = algorithmResult.result();
42+
return LongStream
43+
.range(IdMap.START_NODE_ID, graph.nodeCount())
44+
.mapToObj(i -> DefaultNodeEmbeddingsStreamResult.create(
45+
graph.toOriginalNodeId(i),
46+
hashGNNResult.embeddings().doubleArrayValue(i)
47+
));
48+
}
49+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.algorithms.embeddings.stream;
21+
22+
import org.neo4j.gds.algorithms.embeddings.FloatEmbeddingNodePropertyValues;
23+
import org.neo4j.gds.api.Graph;
24+
import org.neo4j.gds.api.IdMap;
25+
import org.neo4j.gds.embeddings.node2vec.Node2VecResult;
26+
import org.neo4j.gds.procedures.algorithms.embeddings.DefaultNodeEmbeddingsStreamResult;
27+
import org.neo4j.gds.result.TimedAlgorithmResult;
28+
import org.neo4j.gds.results.ResultTransformer;
29+
30+
import java.util.stream.LongStream;
31+
import java.util.stream.Stream;
32+
33+
class Node2VecStreamResultTransformer implements ResultTransformer<TimedAlgorithmResult<Node2VecResult>, Stream<DefaultNodeEmbeddingsStreamResult>> {
34+
private final Graph graph;
35+
36+
Node2VecStreamResultTransformer(Graph graph) {
37+
this.graph = graph;
38+
}
39+
40+
@Override
41+
public Stream<DefaultNodeEmbeddingsStreamResult> apply(TimedAlgorithmResult<Node2VecResult> algorithmResult) {
42+
var node2VecResult = algorithmResult.result();
43+
var nodePropertyValues = new FloatEmbeddingNodePropertyValues(node2VecResult.embeddings());
44+
45+
return LongStream
46+
.range(IdMap.START_NODE_ID, graph.nodeCount())
47+
.filter(nodePropertyValues::hasValue)
48+
.mapToObj(nodeId -> DefaultNodeEmbeddingsStreamResult.create(
49+
graph.toOriginalNodeId(nodeId),
50+
nodePropertyValues.floatArrayValue(nodeId)
51+
));
52+
}
53+
}

0 commit comments

Comments
 (0)