From 4c96391972181927ada1516db916f6efb8140881 Mon Sep 17 00:00:00 2001 From: Michael Hunger Date: Thu, 13 Sep 2018 13:21:26 +0200 Subject: [PATCH 1/2] Added Euclidean and Consine Distance Procedures * refactored all similarity procedures to use the same base mechanisms and datastructures with passed in delegates for the actual work * added tests * todo documentation for cosine and euclidean --- .../java/org/neo4j/graphalgo/JaccardProc.java | 360 ---------------- .../graphalgo/impl/util/TopKConsumer.java | 27 +- .../impl/yens/SimilarityExporter.java | 7 +- .../similarity/CategoricalInput.java | 29 ++ .../graphalgo/similarity/CosineProc.java | 88 ++++ .../graphalgo/similarity/EuclideanProc.java | 87 ++++ .../graphalgo/similarity/JaccardProc.java | 68 ++++ .../Similarities.java} | 4 +- .../similarity/SimilarityConsumer.java | 5 + .../graphalgo/similarity/SimilarityProc.java | 284 +++++++++++++ .../{ => similarity}/SimilarityResult.java | 53 +-- .../SimilaritySummaryResult.java | 28 +- .../graphalgo/similarity/WeightedInput.java | 49 +++ .../graphalgo/impl/util/TopKConsumerTest.java | 16 +- .../bench/CosineSimilarityBenchmark.java | 85 ++++ .../bench/SquareDeltasBenchmark.java | 82 ++++ .../graphalgo/core/utils/Intersections.java | 96 +++++ .../scripts/similarity-jaccard.cypher | 20 +- doc/asciidoc/similarity-jaccard.adoc | 8 +- .../org/neo4j/graphalgo/algo/JaccardTest.java | 385 ------------------ .../graphalgo/algo/similarity/CosineTest.java | 364 +++++++++++++++++ .../algo/similarity/EuclideanTest.java | 381 +++++++++++++++++ .../algo/similarity/JaccardTest.java | 313 ++++++++++++++ .../SimilaritiesTest.java} | 8 +- 24 files changed, 2034 insertions(+), 813 deletions(-) delete mode 100644 algo/src/main/java/org/neo4j/graphalgo/JaccardProc.java create mode 100644 algo/src/main/java/org/neo4j/graphalgo/similarity/CategoricalInput.java create mode 100644 algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java create mode 100644 algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java create mode 100644 algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java rename algo/src/main/java/org/neo4j/graphalgo/{Similarity.java => similarity/Similarities.java} (98%) create mode 100644 algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityConsumer.java create mode 100644 algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java rename algo/src/main/java/org/neo4j/graphalgo/{ => similarity}/SimilarityResult.java (60%) rename algo/src/main/java/org/neo4j/graphalgo/{ => similarity}/SimilaritySummaryResult.java (67%) create mode 100644 algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java create mode 100644 benchmark/src/main/java/org/neo4j/graphalgo/bench/CosineSimilarityBenchmark.java create mode 100644 benchmark/src/main/java/org/neo4j/graphalgo/bench/SquareDeltasBenchmark.java delete mode 100644 tests/src/test/java/org/neo4j/graphalgo/algo/JaccardTest.java create mode 100644 tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java create mode 100644 tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java create mode 100644 tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java rename tests/src/test/java/org/neo4j/graphalgo/algo/{SimilarityTest.java => similarity/SimilaritiesTest.java} (98%) diff --git a/algo/src/main/java/org/neo4j/graphalgo/JaccardProc.java b/algo/src/main/java/org/neo4j/graphalgo/JaccardProc.java deleted file mode 100644 index 5a2a5ef34..000000000 --- a/algo/src/main/java/org/neo4j/graphalgo/JaccardProc.java +++ /dev/null @@ -1,360 +0,0 @@ -/** - * Copyright (c) 2017 "Neo4j, Inc." - * - * This file is part of Neo4j Graph Algorithms . - * - * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.graphalgo; - -import com.carrotsearch.hppc.LongHashSet; -import org.HdrHistogram.DoubleHistogram; -import org.neo4j.graphalgo.core.ProcedureConfiguration; -import org.neo4j.graphalgo.core.utils.ParallelUtil; -import org.neo4j.graphalgo.core.utils.Pools; -import org.neo4j.graphalgo.core.utils.QueueBasedSpliterator; -import org.neo4j.graphalgo.core.utils.TerminationFlag; -import org.neo4j.graphalgo.impl.util.TopKConsumer; -import org.neo4j.graphalgo.impl.yens.SimilarityExporter; -import org.neo4j.kernel.api.KernelTransaction; -import org.neo4j.kernel.internal.GraphDatabaseAPI; -import org.neo4j.logging.Log; -import org.neo4j.procedure.*; - -import java.util.*; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Consumer; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; - -import static org.neo4j.graphalgo.impl.util.TopKConsumer.topK; - -public class JaccardProc { - - @Context - public GraphDatabaseAPI api; - - @Context - public Log log; - - @Context - public KernelTransaction transaction; - - - @Procedure(name = "algo.similarity.jaccard.stream", mode = Mode.READ) - @Description("CALL algo.similarity.jaccard.stream([{source:id, targets:[ids]}], {similarityCutoff:-1,degreeCutoff:0}) " + - "YIELD source1, source2, count1, count2, intersection, jaccard - computes jaccard similarities") - public Stream jaccardStream( - @Name(value = "data", defaultValue = "null") List> data, - @Name(value = "config", defaultValue = "{}") Map config) { - ProcedureConfiguration configuration = ProcedureConfiguration.create(config); - double similarityCutoff = configuration.get("similarityCutoff", -1D); - long degreeCutoff = configuration.get("degreeCutoff", 0L); - - InputData[] ids = fillIds(data, degreeCutoff); - int length = ids.length; - int topN = configuration.getInt("top",0); - int topK = configuration.getInt("topK",0); - TerminationFlag terminationFlag = TerminationFlag.wrap(transaction); - - int concurrency = configuration.getConcurrency(); - - return jaccardStreamMe(ids, length, terminationFlag, concurrency, similarityCutoff, topN, topK); - } - - @Procedure(name = "algo.similarity.jaccard", mode = Mode.WRITE) - @Description("CALL algo.similarity.jaccard([{source:id, targets:[ids]}], {similarityCutoff:-1,degreeCutoff:0}) " + - "YIELD p50, p75, p90, p99, p999, p100 - computes jaccard similarities") - public Stream jaccard( - @Name(value = "data", defaultValue = "null") List> data, - @Name(value = "config", defaultValue = "{}") Map config) { - ProcedureConfiguration configuration = ProcedureConfiguration.create(config); - - double similarityCutoff = configuration.get("similarityCutoff", -1D); - long degreeCutoff = configuration.get("degreeCutoff", 0L); - - InputData[] ids = fillIds(data, degreeCutoff); - long length = ids.length; - TerminationFlag terminationFlag = TerminationFlag.wrap(transaction); - int concurrency = configuration.getConcurrency(); - int topN = configuration.getInt("top",0); - int topK = configuration.getInt("topK",0); - - DoubleHistogram histogram = new DoubleHistogram(5); - AtomicLong similarityPairs = new AtomicLong(); - Stream stream = jaccardStreamMe(ids, (int) length, terminationFlag, concurrency, similarityCutoff, topN, topK); - - String writeRelationshipType = configuration.get("writeRelationshipType", "SIMILAR"); - String writeProperty = configuration.getWriteProperty("score"); - boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0; - if(write) { - SimilarityExporter similarityExporter = new SimilarityExporter(api, writeRelationshipType, writeProperty); - - Stream similarities = stream.peek(recordInHistogram(histogram, similarityPairs)); - similarityExporter.export(similarities); - - } else { - stream.forEach(recordInHistogram(histogram, similarityPairs)); - } - - SimilaritySummaryResult result = new SimilaritySummaryResult( - length, - similarityPairs.get(), - write, - writeRelationshipType, - writeProperty, - histogram.getMinValue(), - histogram.getMaxValue(), - histogram.getMean(), - histogram.getStdDeviation(), - histogram.getValueAtPercentile(25D), - histogram.getValueAtPercentile(50D), - histogram.getValueAtPercentile(75D), - histogram.getValueAtPercentile(90D), - histogram.getValueAtPercentile(95D), - histogram.getValueAtPercentile(99D), - histogram.getValueAtPercentile(99.9D), - histogram.getValueAtPercentile(100D) - ); - - - - - return Stream.of(result); - } - - private Consumer recordInHistogram(DoubleHistogram histogram, AtomicLong similarityPairs) { - return result -> { - try { - histogram.recordValue(result.similarity); - } catch (ArrayIndexOutOfBoundsException ignored) { - - } - similarityPairs.getAndIncrement(); - }; - } - - private Stream jaccardStreamMe(InputData[] ids, int length, TerminationFlag terminationFlag, int concurrency, double similarityCutoff, int topN, int topK) { - if (concurrency == 1) { - if (topK > 0) { - return jaccardStreamTopK(ids, length, similarityCutoff, topN, topK); - } else { - return jaccardStream(ids, length, similarityCutoff, topN); - } - } else { - if (topK > 0) { - return jaccardParallelStreamTopK(ids, length, terminationFlag, concurrency, similarityCutoff, topN, topK); - } else { - return jaccardParallelStream(ids, length, terminationFlag, concurrency, similarityCutoff, topN); - } - } - } - - private Stream jaccardStream(InputData[] ids, int length, double similiarityCutoff, int topN) { - Stream stream = IntStream.range(0, length) - .boxed().flatMap(sourceId -> IntStream.range(sourceId + 1, length) - .mapToObj(targetId -> calculateJaccard(similiarityCutoff, ids[sourceId], ids[targetId])).filter(Objects::nonNull)); - return topN(stream,topN); - } - - private Stream jaccardStreamTopK(InputData[] ids, int length, double similiarityCutoff, int topN, int topK) { - TopKConsumer[] topKHolder = initializeTopKConsumers(length, topK); - - for (int sourceId = 0;sourceId < length;sourceId++) { - computeJaccardForSourceIndex(sourceId, ids, length, similiarityCutoff, (sourceIndex, targetIndex, similarityResult) -> { - topKHolder[sourceIndex].accept(similarityResult); - topKHolder[targetIndex].accept(similarityResult.reverse()); - }); - } - return topN(Arrays.stream(topKHolder).flatMap(TopKConsumer::stream),topN); - } - - interface SimilarityConsumer { - void accept(int sourceIndex, int targetIndex, SimilarityResult result); - } - - private TopKConsumer[] initializeTopKConsumers(int length, int topK) { - TopKConsumer[] results = new TopKConsumer[length]; - for (int i = 0; i < results.length; i++) results[i] = new TopKConsumer<>(topK); - return results; - } - - private Stream jaccardParallelStream(InputData[] ids, int length, TerminationFlag terminationFlag, int concurrency, double similiarityCutoff, int topN) { - - int timeout = 100; - int queueSize = 1000; - - int batchSize = ParallelUtil.adjustBatchSize(length, concurrency, 1); - int taskCount = (length / batchSize) + 1; - Collection tasks = new ArrayList<>(taskCount); - - ArrayBlockingQueue queue = new ArrayBlockingQueue<>(queueSize); - - int multiplier = batchSize < length ? batchSize : 1; - for (int taskId = 0; taskId < taskCount; taskId++) { - int taskOffset = taskId; - tasks.add(() -> { - for (int offset = 0; offset < batchSize; offset++) { - int sourceId = taskOffset * multiplier + offset; - if (sourceId < length) - computeJaccardForSourceIndex(sourceId, ids, length, similiarityCutoff, (s, t, result) -> put(queue, result)); - } - }); - } - - new Thread(() -> { - try { - ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT); - } finally { - put(queue, SimilarityResult.TOMB); - } - }).start(); - - QueueBasedSpliterator spliterator = new QueueBasedSpliterator<>(queue, SimilarityResult.TOMB, terminationFlag, timeout); - Stream stream = StreamSupport.stream(spliterator, false); - return topN(stream, topN); - } - - - private Stream jaccardParallelStreamTopK(InputData[] ids, int length, TerminationFlag terminationFlag, int concurrency, double similiarityCutoff, int topN, int topK) { - int batchSize = ParallelUtil.adjustBatchSize(length, concurrency, 1); - int taskCount = (length / batchSize) + (length % batchSize > 0 ? 1 : 0); - Collection tasks = new ArrayList<>(taskCount); - - int multiplier = batchSize < length ? batchSize : 1; - for (int taskId = 0; taskId < taskCount; taskId++) { - tasks.add(new TopKTask(batchSize, taskId, multiplier, length, ids, similiarityCutoff, topK)); - } - - ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT); - - TopKConsumer[] topKConsumers = initializeTopKConsumers(length, topK); - for (Runnable task : tasks) ((TopKTask)task).mergeInto(topKConsumers); - Stream stream = Arrays.stream(topKConsumers).flatMap(TopKConsumer::stream); - return topN(stream, topN); - } - - private void computeJaccardForSourceIndex(int sourceId, InputData[] ids, int length, double similiarityCutoff, SimilarityConsumer consumer) { - for (int targetId=sourceId+1;targetId topN(Stream stream, int topN) { - if (topN <= 0) { - return stream; - } - if (topN > 10000) { - return stream.sorted().limit(topN); - } - return topK(stream,topN); - } - - private SimilarityResult calculateJaccard(double similarityCutoff, InputData e1, InputData e2) { - return SimilarityResult.of(e1.id, e2.id, e1.targets, e2.targets, similarityCutoff); - } - - private static void put(BlockingQueue queue, T items) { - try { - queue.put(items); - } catch (InterruptedException e) { - // ignore - } - } - - private static class InputData implements Comparable { - long id; - long[] targets; - - public InputData(long id, long[] targets) { - this.id = id; - this.targets = targets; - } - - @Override - public int compareTo(InputData o) { - return Long.compare(id, o.id); - } - } - - private InputData[] fillIds(@Name(value = "data", defaultValue = "null") List> data, long degreeCutoff) { - InputData[] ids = new InputData[data.size()]; - int idx = 0; - for (Map row : data) { - List targetIds = (List) row.get("targets"); - int size = targetIds.size(); - if ( size > degreeCutoff) { - long[] targets = new long[size]; - int i=0; - for (Long id : targetIds) { - targets[i++]=id; - } - Arrays.sort(targets); - ids[idx++] = new InputData((Long) row.get("source"), targets); - } - } - if (idx != ids.length) ids = Arrays.copyOf(ids, idx); - Arrays.sort(ids); - return ids; - } - - private class TopKTask implements Runnable { - private final int batchSize; - private final int taskOffset; - private final int multiplier; - private final int length; - private final InputData[] ids; - private final double similiarityCutoff; - private final TopKConsumer[] topKConsumers; - - public TopKTask(int batchSize, int taskOffset, int multiplier, int length, InputData[] ids, double similiarityCutoff, int topK) { - this.batchSize = batchSize; - this.taskOffset = taskOffset; - this.multiplier = multiplier; - this.length = length; - this.ids = ids; - this.similiarityCutoff = similiarityCutoff; - topKConsumers = initializeTopKConsumers(length, topK); - } - - @Override - public void run() { - for (int offset = 0; offset < batchSize; offset++) { - int sourceId = taskOffset * multiplier + offset; - if (sourceId < length) { - JaccardProc.this.computeJaccardForSourceIndex(sourceId, ids, length, similiarityCutoff, (s, t, result) -> { - topKConsumers[s].accept(result); - topKConsumers[t].accept(result.reverse()); - }); - } - } - } - public void mergeInto(TopKConsumer[] target) { - for (int i = 0; i < target.length; i++) { - target[i].accept(topKConsumers[i]); - } - } - } - - - // roaring bitset - // test with JMH -} diff --git a/algo/src/main/java/org/neo4j/graphalgo/impl/util/TopKConsumer.java b/algo/src/main/java/org/neo4j/graphalgo/impl/util/TopKConsumer.java index 460e70f1e..83cc4b42f 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/impl/util/TopKConsumer.java +++ b/algo/src/main/java/org/neo4j/graphalgo/impl/util/TopKConsumer.java @@ -1,39 +1,42 @@ package org.neo4j.graphalgo.impl.util; import java.util.Arrays; +import java.util.Comparator; import java.util.List; import java.util.function.Consumer; -import java.util.stream.Collectors; import java.util.stream.Stream; -public class TopKConsumer> implements Consumer { +public class TopKConsumer implements Consumer { private final int topK; private final T[] heap; + private Comparator comparator; private int count; private T minValue; - public TopKConsumer(int topK) { + public TopKConsumer(int topK, Comparator comparator) { this.topK = topK; - heap = (T[]) new Comparable[topK]; + heap = (T[]) new Object[topK]; + this.comparator = comparator; count = 0; minValue = null; } - public static > List topK(List items, int topK) { - TopKConsumer consumer = new TopKConsumer<>(topK); - items.stream().forEach(consumer); + public static List topK(List items, int topK, Comparator comparator) { + TopKConsumer consumer = new TopKConsumer<>(topK, comparator); + items.forEach(consumer); return consumer.list(); } - public static > Stream topK(Stream items, int topK) { - TopKConsumer consumer = new TopKConsumer<>(topK); + + public static Stream topK(Stream items, int topK, Comparator comparator) { + TopKConsumer consumer = new TopKConsumer(topK, comparator); items.forEach(consumer); return consumer.stream(); } @Override public void accept(T item) { - if (count < topK || minValue == null || item.compareTo(minValue) < 0) { - int idx = Arrays.binarySearch(heap, 0, count, item); + if (count < topK || minValue == null || comparator.compare(item,minValue) < 0) { + int idx = Arrays.binarySearch(heap, 0, count, item, comparator); idx = (idx < 0) ? -idx : idx + 1; int length = topK - idx; if (length > 0 && idx < topK) System.arraycopy(heap,idx-1,heap,idx, length); @@ -53,7 +56,7 @@ public List list() { } public void accept(TopKConsumer other) { - if (minValue == null || count < topK || other.minValue != null && other.minValue.compareTo(minValue) < 0) { + if (minValue == null || count < topK || other.minValue != null && comparator.compare(other.minValue,minValue) < 0) { for (int i=0;i similarityPairs) { private void export(SimilarityResult similarityResult) { applyInTransaction(statement -> { - long node1 = similarityResult.source1; - long node2 = similarityResult.source2; + long node1 = similarityResult.item1; + long node2 = similarityResult.item2; try { long relationshipId = statement.dataWrite().relationshipCreate(node1, relationshipTypeId, node2); diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/CategoricalInput.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/CategoricalInput.java new file mode 100644 index 000000000..761f019e9 --- /dev/null +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/CategoricalInput.java @@ -0,0 +1,29 @@ +package org.neo4j.graphalgo.similarity; + +import org.neo4j.graphalgo.core.utils.Intersections; + +class CategoricalInput implements Comparable { + long id; + long[] targets; + + public CategoricalInput(long id, long[] targets) { + this.id = id; + this.targets = targets; + } + + @Override + public int compareTo(CategoricalInput o) { + return Long.compare(id, o.id); + } + + SimilarityResult jaccard(double similarityCutoff, CategoricalInput e2) { + long intersection = Intersections.intersection3(targets, e2.targets); + if (similarityCutoff >= 0d && intersection == 0) return null; + int count1 = targets.length; + int count2 = e2.targets.length; + long denominator = count1 + count2 - intersection; + double jaccard = denominator == 0 ? 0 : (double)intersection / denominator; + if (jaccard < similarityCutoff) return null; + return new SimilarityResult(id, e2.id, count1, count2, intersection, jaccard); + } +} diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java new file mode 100644 index 000000000..67e11ef3e --- /dev/null +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java @@ -0,0 +1,88 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.similarity; + +import org.neo4j.graphalgo.core.ProcedureConfiguration; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +public class CosineProc extends SimilarityProc { + + @Procedure(name = "algo.similarity.cosine.stream", mode = Mode.READ) + @Description("CALL algo.similarity.cosine.stream([{source:id, weights:[weights]}], {similarityCutoff:-1,degreeCutoff:0}) " + + "YIELD item1, item2, count1, count2, intersection, similarity - computes cosine distance") + // todo count1,count2 = could be the non-null values, intersection the values where both are non-null? + public Stream cosineStream( + @Name(value = "data", defaultValue = "null") List> data, + @Name(value = "config", defaultValue = "{}") Map config) { + + SimilarityComputer computer = (s,t,cutoff) -> s.cosineSquares(cutoff, t); + + ProcedureConfiguration configuration = ProcedureConfiguration.create(config); + + WeightedInput[] inputs = prepareWeights(data, getDegreeCutoff(configuration)); + + double similarityCutoff = getSimilarityCutoff(configuration); + // as we don't compute the sqrt until the end + if (similarityCutoff > 0d) similarityCutoff *= similarityCutoff; + + int topN = getTopN(configuration); + int topK = getTopK(configuration); + + Stream stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, topK), topN); + + return stream.map(SimilarityResult::squareRooted); + } + + @Procedure(name = "algo.similarity.cosine", mode = Mode.WRITE) + @Description("CALL algo.similarity.cosine([{item:id, weights:[weights]}], {similarityCutoff:-1,degreeCutoff:0}) " + + "YIELD p50, p75, p90, p99, p999, p100 - computes cosine similarities") + public Stream cosine( + @Name(value = "data", defaultValue = "null") List> data, + @Name(value = "config", defaultValue = "{}") Map config) { + + SimilarityComputer computer = (s,t,cutoff) -> s.cosineSquares(cutoff, t); + + ProcedureConfiguration configuration = ProcedureConfiguration.create(config); + + WeightedInput[] inputs = prepareWeights(data, getDegreeCutoff(configuration)); + + double similarityCutoff = getSimilarityCutoff(configuration); + // as we don't compute the sqrt until the end + if (similarityCutoff > 0d) similarityCutoff *= similarityCutoff; + + int topN = getTopN(configuration); + int topK = getTopK(configuration); + + Stream stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, topK), topN) + .map(SimilarityResult::squareRooted); + + + boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0; + return writeAndAggregateResults(configuration, stream, inputs.length, write); + } + + +} diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java new file mode 100644 index 000000000..0a0c0a765 --- /dev/null +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java @@ -0,0 +1,87 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.similarity; + +import org.neo4j.graphalgo.core.ProcedureConfiguration; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +public class EuclideanProc extends SimilarityProc { + + @Procedure(name = "algo.similarity.euclidean.stream", mode = Mode.READ) + @Description("CALL algo.similarity.euclidean.stream([{source:id, weights:[weights]}], {similarityCutoff:-1,degreeCutoff:0}) " + + "YIELD item1, item2, count1, count2, intersection, similarity - computes euclidean distance") + // todo count1,count2 = could be the non-null values, intersection the values where both are non-null? + public Stream euclideanStream( + @Name(value = "data", defaultValue = "null") List> data, + @Name(value = "config", defaultValue = "{}") Map config) { + + SimilarityComputer computer = (s,t,cutoff) -> s.sumSquareDelta(cutoff, t); + + ProcedureConfiguration configuration = ProcedureConfiguration.create(config); + + WeightedInput[] inputs = prepareWeights(data, getDegreeCutoff(configuration)); + + double similarityCutoff = getSimilarityCutoff(configuration); + // as we don't compute the sqrt until the end + if (similarityCutoff > 0d) similarityCutoff *= similarityCutoff; + + int topN = -getTopN(configuration); + int topK = -getTopK(configuration); + + Stream stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, topK), topN); + + return stream.map(SimilarityResult::squareRooted); + } + + @Procedure(name = "algo.similarity.euclidean", mode = Mode.WRITE) + @Description("CALL algo.similarity.euclidean([{item:id, weights:[weights]}], {similarityCutoff:-1,degreeCutoff:0}) " + + "YIELD p50, p75, p90, p99, p999, p100 - computes euclidean similarities") + public Stream euclidean( + @Name(value = "data", defaultValue = "null") List> data, + @Name(value = "config", defaultValue = "{}") Map config) { + + SimilarityComputer computer = (s,t,cutoff) -> s.sumSquareDelta(cutoff, t); + + ProcedureConfiguration configuration = ProcedureConfiguration.create(config); + + WeightedInput[] inputs = prepareWeights(data, getDegreeCutoff(configuration)); + + double similarityCutoff = getSimilarityCutoff(configuration); + // as we don't compute the sqrt until the end + if (similarityCutoff > 0d) similarityCutoff *= similarityCutoff; + + int topN = -getTopN(configuration); + int topK = -getTopK(configuration); + + Stream stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, topK), topN) + .map(SimilarityResult::squareRooted); + + boolean write = configuration.isWriteFlag(false); // && similarityCutoff != 0.0; + return writeAndAggregateResults(configuration, stream, inputs.length, write); + } + + +} diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java new file mode 100644 index 000000000..594d8fb6b --- /dev/null +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java @@ -0,0 +1,68 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.similarity; + +import org.neo4j.graphalgo.core.ProcedureConfiguration; +import org.neo4j.procedure.*; + +import java.util.*; +import java.util.stream.Stream; + +import static org.neo4j.graphalgo.impl.util.TopKConsumer.topK; + +public class JaccardProc extends SimilarityProc { + + @Procedure(name = "algo.similarity.jaccard.stream", mode = Mode.READ) + @Description("CALL algo.similarity.jaccard.stream([{source:id, targets:[ids]}], {similarityCutoff:-1,degreeCutoff:0}) " + + "YIELD item1, item2, count1, count2, intersection, similarity - computes jaccard similarities") + public Stream similarityStream( + @Name(value = "data", defaultValue = "null") List> data, + @Name(value = "config", defaultValue = "{}") Map config) { + + SimilarityComputer computer = (s, t, cutoff) -> s.jaccard(cutoff, t); + + ProcedureConfiguration configuration = ProcedureConfiguration.create(config); + + CategoricalInput[] inputs = prepareCategories(data, getDegreeCutoff(configuration)); + + return topN(similarityStream(inputs, computer, configuration, getSimilarityCutoff(configuration), getTopK(configuration)), getTopN(configuration)); + } + + @Procedure(name = "algo.similarity.jaccard", mode = Mode.WRITE) + @Description("CALL algo.similarity.jaccard([{source:id, targets:[ids]}], {similarityCutoff:-1,degreeCutoff:0}) " + + "YIELD p50, p75, p90, p99, p999, p100 - computes jaccard similarities") + public Stream jaccard( + @Name(value = "data", defaultValue = "null") List> data, + @Name(value = "config", defaultValue = "{}") Map config) { + + SimilarityComputer computer = (s,t,cutoff) -> s.jaccard(cutoff, t); + + ProcedureConfiguration configuration = ProcedureConfiguration.create(config); + + CategoricalInput[] inputs = prepareCategories(data, getDegreeCutoff(configuration)); + + double similarityCutoff = getSimilarityCutoff(configuration); + Stream stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, getTopK(configuration)), getTopN(configuration)); + + boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0; + return writeAndAggregateResults(configuration, stream, inputs.length, write); + } + + +} diff --git a/algo/src/main/java/org/neo4j/graphalgo/Similarity.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java similarity index 98% rename from algo/src/main/java/org/neo4j/graphalgo/Similarity.java rename to algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java index 2a01e2065..b8e7ef136 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/Similarity.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java @@ -16,7 +16,7 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package org.neo4j.graphalgo; +package org.neo4j.graphalgo.similarity; import org.neo4j.procedure.Description; import org.neo4j.procedure.Name; @@ -25,7 +25,7 @@ import java.util.HashSet; import java.util.List; -public class Similarity { +public class Similarities { @UserFunction("algo.similarity.jaccard") @Description("algo.similarity.jaccard([vector1], [vector2]) " + diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityConsumer.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityConsumer.java new file mode 100644 index 000000000..4861e5d90 --- /dev/null +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityConsumer.java @@ -0,0 +1,5 @@ +package org.neo4j.graphalgo.similarity; + +interface SimilarityConsumer { + void accept(int sourceIndex, int targetIndex, SimilarityResult result); +} diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java new file mode 100644 index 000000000..487158f1a --- /dev/null +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java @@ -0,0 +1,284 @@ +package org.neo4j.graphalgo.similarity; + +import org.HdrHistogram.DoubleHistogram; +import org.neo4j.graphalgo.core.ProcedureConfiguration; +import org.neo4j.graphalgo.core.utils.ParallelUtil; +import org.neo4j.graphalgo.core.utils.Pools; +import org.neo4j.graphalgo.core.utils.QueueBasedSpliterator; +import org.neo4j.graphalgo.core.utils.TerminationFlag; +import org.neo4j.graphalgo.impl.util.TopKConsumer; +import org.neo4j.graphalgo.impl.yens.SimilarityExporter; +import org.neo4j.kernel.api.KernelTransaction; +import org.neo4j.kernel.internal.GraphDatabaseAPI; +import org.neo4j.logging.Log; +import org.neo4j.procedure.*; + +import java.util.*; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +import static org.neo4j.graphalgo.impl.util.TopKConsumer.topK; + +public class SimilarityProc { + @Context + public GraphDatabaseAPI api; + @Context + public Log log; + @Context + public KernelTransaction transaction; + + private static TopKConsumer[] initializeTopKConsumers(int length, int topK) { + Comparator comparator = topK > 0 ? SimilarityResult.DESCENDING : SimilarityResult.ASCENDING; + topK = Math.abs(topK); + + TopKConsumer[] results = new TopKConsumer[length]; + for (int i = 0; i < results.length; i++) results[i] = new TopKConsumer<>(topK, comparator); + return results; + } + + static Stream topN(Stream stream, int topN) { + if (topN == 0) { + return stream; + } + Comparator comparator = topN > 0 ? SimilarityResult.DESCENDING : SimilarityResult.ASCENDING; + topN = Math.abs(topN); + + if (topN > 10000) { + return stream.sorted(comparator).limit(topN); + } + return topK(stream,topN, comparator); + } + + private static void put(BlockingQueue queue, T items) { + try { + queue.put(items); + } catch (InterruptedException e) { + // ignore + } + } + + Long getDegreeCutoff(ProcedureConfiguration configuration) { + return configuration.get("degreeCutoff", 0L); + } + + Stream writeAndAggregateResults(ProcedureConfiguration configuration, Stream stream, int length, boolean write) { + String writeRelationshipType = configuration.get("writeRelationshipType", "SIMILAR"); + String writeProperty = configuration.getWriteProperty("score"); + + AtomicLong similarityPairs = new AtomicLong(); + DoubleHistogram histogram = new DoubleHistogram(5); + Consumer recorder = result -> { + result.record(histogram); + similarityPairs.getAndIncrement(); + }; + + if(write) { + SimilarityExporter similarityExporter = new SimilarityExporter(api, writeRelationshipType, writeProperty); + similarityExporter.export(stream.peek(recorder)); + } else { + stream.forEach(recorder); + } + + return Stream.of(SimilaritySummaryResult.from(length, similarityPairs, writeRelationshipType, writeProperty, write, histogram)); + } + + Double getSimilarityCutoff(ProcedureConfiguration configuration) { + return configuration.get("similarityCutoff", -1D); + } + + Stream similarityStream(T[] inputs, SimilarityComputer computer, ProcedureConfiguration configuration, double cutoff, int topK) { + TerminationFlag terminationFlag = TerminationFlag.wrap(transaction); + int concurrency = configuration.getConcurrency(); + + int length = inputs.length; + if (concurrency == 1) { + if (topK != 0) { + return similarityStreamTopK(inputs, length, cutoff, topK, computer); + } else { + return similarityStream(inputs, length, cutoff, computer); + } + } else { + if (topK != 0) { + return similarityParallelStreamTopK(inputs, length, terminationFlag, concurrency, cutoff, topK, computer); + } else { + return similarityParallelStream(inputs, length, terminationFlag, concurrency, cutoff, computer); + } + } + } + + private Stream similarityStream(T[] inputs, int length, double similiarityCutoff, SimilarityComputer computer) { + return IntStream.range(0, length) + .boxed().flatMap(sourceId -> IntStream.range(sourceId + 1, length) + .mapToObj(targetId -> computer.similarity(inputs[sourceId],inputs[targetId],similiarityCutoff)).filter(Objects::nonNull)); + } + + private Stream similarityStreamTopK(T[] inputs, int length, double cutoff, int topK, SimilarityComputer computer) { + TopKConsumer[] topKHolder = initializeTopKConsumers(length, topK); + + for (int sourceId = 0;sourceId < length;sourceId++) { + computeSimilarityForSourceIndex(sourceId, inputs, length, cutoff, (sourceIndex, targetIndex, similarityResult) -> { + topKHolder[sourceIndex].accept(similarityResult); + topKHolder[targetIndex].accept(similarityResult.reverse()); + }, computer); + } + return Arrays.stream(topKHolder).flatMap(TopKConsumer::stream); + } + + private Stream similarityParallelStream(T[] inputs, int length, TerminationFlag terminationFlag, int concurrency, double cutoff, SimilarityComputer computer) { + + int timeout = 100; + int queueSize = 1000; + + int batchSize = ParallelUtil.adjustBatchSize(length, concurrency, 1); + int taskCount = (length / batchSize) + (length % batchSize > 0 ? 1 : 0); + Collection tasks = new ArrayList<>(taskCount); + + ArrayBlockingQueue queue = new ArrayBlockingQueue<>(queueSize); + + int multiplier = batchSize < length ? batchSize : 1; + for (int taskId = 0; taskId < taskCount; taskId++) { + int taskOffset = taskId; + tasks.add(() -> { + for (int offset = 0; offset < batchSize; offset++) { + int sourceId = taskOffset * multiplier + offset; + if (sourceId < length) + computeSimilarityForSourceIndex(sourceId, inputs, length, cutoff, (s, t, result) -> put(queue, result), computer); + } + }); + } + + new Thread(() -> { + try { + ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT); + } finally { + put(queue, SimilarityResult.TOMB); + } + }).start(); + + QueueBasedSpliterator spliterator = new QueueBasedSpliterator<>(queue, SimilarityResult.TOMB, terminationFlag, timeout); + return StreamSupport.stream(spliterator, false); + } + + private Stream similarityParallelStreamTopK(T[] inputs, int length, TerminationFlag terminationFlag, int concurrency, double cutoff, int topK, SimilarityComputer computer) { + int batchSize = ParallelUtil.adjustBatchSize(length, concurrency, 1); + int taskCount = (length / batchSize) + (length % batchSize > 0 ? 1 : 0); + Collection tasks = new ArrayList<>(taskCount); + + int multiplier = batchSize < length ? batchSize : 1; + for (int taskId = 0; taskId < taskCount; taskId++) { + tasks.add(new TopKTask(batchSize, taskId, multiplier, length, inputs, cutoff, topK, computer)); + } + ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT); + + TopKConsumer[] topKConsumers = initializeTopKConsumers(length, topK); + for (Runnable task : tasks) ((TopKTask)task).mergeInto(topKConsumers); + return Arrays.stream(topKConsumers).flatMap(TopKConsumer::stream); + } + + private void computeSimilarityForSourceIndex(int sourceId, T[] inputs, int length, double cutoff, SimilarityConsumer consumer, SimilarityComputer computer) { + for (int targetId=sourceId+1;targetId> data, long degreeCutoff) { + CategoricalInput[] ids = new CategoricalInput[data.size()]; + int idx = 0; + for (Map row : data) { + List targetIds = (List) row.get("categories"); + int size = targetIds.size(); + if ( size > degreeCutoff) { + long[] targets = new long[size]; + int i=0; + for (Long id : targetIds) { + targets[i++]=id; + } + Arrays.sort(targets); + ids[idx++] = new CategoricalInput((Long) row.get("item"), targets); + } + } + if (idx != ids.length) ids = Arrays.copyOf(ids, idx); + Arrays.sort(ids); + return ids; + } + + WeightedInput[] prepareWeights(List> data, long degreeCutoff) { + WeightedInput[] inputs = new WeightedInput[data.size()]; + int idx = 0; + for (Map row : data) { + List weightList = (List) row.get("weights"); + int size = weightList.size(); + if ( size > degreeCutoff) { + double[] weights = new double[size]; + int i=0; + for (Number value : weightList) { + weights[i++]=value.doubleValue(); + } + inputs[idx++] = new WeightedInput((Long) row.get("item"), weights); + } + } + if (idx != inputs.length) inputs = Arrays.copyOf(inputs, idx); + Arrays.sort(inputs); + return inputs; + } + + protected int getTopK(ProcedureConfiguration configuration) { + return configuration.getInt("topK", 0); + } + + protected int getTopN(ProcedureConfiguration configuration) { + return configuration.getInt("top",0); + } + + interface SimilarityComputer { + SimilarityResult similarity(T source, T target, double cutoff); + } + + private class TopKTask implements Runnable { + private final int batchSize; + private final int taskOffset; + private final int multiplier; + private final int length; + private final T[] ids; + private final double similiarityCutoff; + private final SimilarityComputer computer; + private final TopKConsumer[] topKConsumers; + + TopKTask(int batchSize, int taskOffset, int multiplier, int length, T[] ids, double similiarityCutoff, int topK, SimilarityComputer computer) { + this.batchSize = batchSize; + this.taskOffset = taskOffset; + this.multiplier = multiplier; + this.length = length; + this.ids = ids; + this.similiarityCutoff = similiarityCutoff; + this.computer = computer; + topKConsumers = initializeTopKConsumers(length, topK); + } + + @Override + public void run() { + for (int offset = 0; offset < batchSize; offset++) { + int sourceId = taskOffset * multiplier + offset; + if (sourceId < length) { + computeSimilarityForSourceIndex(sourceId, ids, length, similiarityCutoff, (s, t, result) -> { + topKConsumers[s].accept(result); + topKConsumers[t].accept(result.reverse()); + }, computer); + } + } + } + void mergeInto(TopKConsumer[] target) { + for (int i = 0; i < target.length; i++) { + target[i].accept(topKConsumers[i]); + } + } + } +} diff --git a/algo/src/main/java/org/neo4j/graphalgo/SimilarityResult.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityResult.java similarity index 60% rename from algo/src/main/java/org/neo4j/graphalgo/SimilarityResult.java rename to algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityResult.java index 066fba8be..02d9441df 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/SimilarityResult.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityResult.java @@ -16,25 +16,26 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package org.neo4j.graphalgo; +package org.neo4j.graphalgo.similarity; -import org.neo4j.graphalgo.core.utils.Intersections; +import org.HdrHistogram.DoubleHistogram; +import java.util.Comparator; import java.util.Objects; public class SimilarityResult implements Comparable { - public final long source2; + public final long item2; public final long count1; - public final long source1; + public final long item1; public final long count2; public final long intersection; - public final double similarity; + public double similarity; public static SimilarityResult TOMB = new SimilarityResult(-1, -1, -1, -1, -1, -1); - public SimilarityResult(long source1, long source2, long count1, long count2, long intersection, double similarity) { - this.source1 = source1; - this.source2 = source2; + public SimilarityResult(long item1, long item2, long count1, long count2, long intersection, double similarity) { + this.item1 = item1; + this.item2 = item2; this.count1 = count1; this.count2 = count2; this.intersection = intersection; @@ -46,8 +47,8 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SimilarityResult that = (SimilarityResult) o; - return source1 == that.source1 && - source2 == that.source2 && + return item1 == that.item1 && + item2 == that.item2 && count1 == that.count1 && count2 == that.count2 && intersection == that.intersection && @@ -57,26 +58,32 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(source1, source2, count1, count2, intersection, similarity); - } - - public static SimilarityResult of(long source1, long source2, long[] targets1, long[] targets2, double similarityCutoff) { - long intersection = Intersections.intersection3(targets1,targets2); - if (similarityCutoff >= 0d && intersection == 0) return null; - int count1 = targets1.length; - int count2 = targets2.length; - long denominator = count1 + count2 - intersection; - double jaccard = denominator == 0 ? 0 : (double)intersection / denominator; - if (jaccard < similarityCutoff) return null; - return new SimilarityResult(source1, source2, count1, count2, intersection, jaccard); + return Objects.hash(item1, item2, count1, count2, intersection, similarity); } + /** + * sorts by default descending + */ @Override public int compareTo(SimilarityResult o) { return Double.compare(o.similarity,this.similarity); } public SimilarityResult reverse() { - return new SimilarityResult(source2,source1,count2,count1,intersection,similarity); + return new SimilarityResult(item2, item1,count2,count1,intersection,similarity); + } + + public SimilarityResult squareRooted() { + this.similarity = Math.sqrt(this.similarity); + return this; + } + + void record(DoubleHistogram histogram) { + try { + histogram.recordValue(similarity); + } catch (ArrayIndexOutOfBoundsException ignored) { + } } + static Comparator ASCENDING = (o1, o2) -> -o1.compareTo(o2); + static Comparator DESCENDING = SimilarityResult::compareTo; } diff --git a/algo/src/main/java/org/neo4j/graphalgo/SimilaritySummaryResult.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilaritySummaryResult.java similarity index 67% rename from algo/src/main/java/org/neo4j/graphalgo/SimilaritySummaryResult.java rename to algo/src/main/java/org/neo4j/graphalgo/similarity/SimilaritySummaryResult.java index e39eb2523..560673d8e 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/SimilaritySummaryResult.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilaritySummaryResult.java @@ -16,7 +16,11 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package org.neo4j.graphalgo; +package org.neo4j.graphalgo.similarity; + +import org.HdrHistogram.DoubleHistogram; + +import java.util.concurrent.atomic.AtomicLong; public class SimilaritySummaryResult { @@ -61,4 +65,26 @@ public SimilaritySummaryResult(long nodes, long similarityPairs, this.p999 = p999; this.p100 = p100; } + + static SimilaritySummaryResult from(long length, AtomicLong similarityPairs, String writeRelationshipType, String writeProperty, boolean write, DoubleHistogram histogram) { + return new SimilaritySummaryResult( + length, + similarityPairs.get(), + write, + writeRelationshipType, + writeProperty, + histogram.getMinValue(), + histogram.getMaxValue(), + histogram.getMean(), + histogram.getStdDeviation(), + histogram.getValueAtPercentile(25D), + histogram.getValueAtPercentile(50D), + histogram.getValueAtPercentile(75D), + histogram.getValueAtPercentile(90D), + histogram.getValueAtPercentile(95D), + histogram.getValueAtPercentile(99D), + histogram.getValueAtPercentile(99.9D), + histogram.getValueAtPercentile(100D) + ); + } } diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java new file mode 100644 index 000000000..03fc93fcb --- /dev/null +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java @@ -0,0 +1,49 @@ +package org.neo4j.graphalgo.similarity; + +import org.neo4j.graphalgo.core.utils.Intersections; + +import java.util.stream.DoubleStream; + +class WeightedInput implements Comparable { + long id; + double[] weights; + int count; + + public WeightedInput(long id, double[] weights) { + this.id = id; + this.weights = weights; + for (double weight : weights) { + if (weight!=0d) this.count++; + } + } + + @Override + public int compareTo(WeightedInput o) { + return Long.compare(id, o.id); + } + + SimilarityResult sumSquareDelta(double similarityCutoff, WeightedInput other) { + int len = Math.min(weights.length, other.weights.length); + double sumSquareDelta = Intersections.sumSquareDelta(weights, other.weights, len); + long intersection = 0; + /* todo + for (int i = 0; i < len; i++) { + if (weights[i] == other.weights[i] && weights[i] != 0d) intersection++; + } + */ + if (similarityCutoff >= 0d && sumSquareDelta > similarityCutoff) return null; + return new SimilarityResult(id, other.id, count, other.count, intersection, sumSquareDelta); + } + SimilarityResult cosineSquares(double similarityCutoff, WeightedInput other) { + int len = Math.min(weights.length, other.weights.length); + double cosineSquares = Intersections.cosineSquare(weights, other.weights, len); + long intersection = 0; + /* todo + for (int i = 0; i < len; i++) { + if (weights[i] == other.weights[i] && weights[i] != 0d) intersection++; + } + */ + if (similarityCutoff >= 0d && (cosineSquares == 0 || cosineSquares < similarityCutoff)) return null; + return new SimilarityResult(id, other.id, count, other.count, intersection, cosineSquares); + } +} diff --git a/algo/src/test/java/org/neo4j/graphalgo/impl/util/TopKConsumerTest.java b/algo/src/test/java/org/neo4j/graphalgo/impl/util/TopKConsumerTest.java index 14e81c363..4ee414edf 100644 --- a/algo/src/test/java/org/neo4j/graphalgo/impl/util/TopKConsumerTest.java +++ b/algo/src/test/java/org/neo4j/graphalgo/impl/util/TopKConsumerTest.java @@ -56,7 +56,7 @@ public int compareTo(Item o) { @Test public void testFindTopKHeap4() throws Exception { - Collection topItems = TopKConsumer.topK(asList(ITEM1, ITEM3, ITEM2, ITEM4), 4); + Collection topItems = TopKConsumer.topK(asList(ITEM1, ITEM3, ITEM2, ITEM4), 4, Item::compareTo); assertEquals(asList(ITEM4,ITEM3,ITEM2,ITEM1),topItems); for (Item topItem : topItems) { System.out.println(topItem); @@ -65,7 +65,7 @@ public void testFindTopKHeap4() throws Exception { @Test public void testFindTopKHeap2of4() throws Exception { - Collection topItems = TopKConsumer.topK(asList(ITEM2, ITEM4), 4); + Collection topItems = TopKConsumer.topK(asList(ITEM2, ITEM4), 4, Item::compareTo); assertEquals(asList(ITEM4,ITEM2),topItems); for (Item topItem : topItems) { System.out.println(topItem); @@ -73,7 +73,7 @@ public void testFindTopKHeap2of4() throws Exception { } @Test public void testFindTopKHeap4of3() throws Exception { - Collection topItems = TopKConsumer.topK(asList(ITEM2, ITEM1, ITEM4, ITEM3), 3); + Collection topItems = TopKConsumer.topK(asList(ITEM2, ITEM1, ITEM4, ITEM3), 3, Item::compareTo); assertEquals(asList(ITEM4,ITEM3,ITEM2),topItems); for (Item topItem : topItems) { System.out.println(topItem); @@ -81,7 +81,7 @@ public void testFindTopKHeap4of3() throws Exception { } @Test public void testFindTopKHeap() throws Exception { - Collection topItems = TopKConsumer.topK(asList(ITEM1, ITEM3, ITEM2, ITEM4), 2); + Collection topItems = TopKConsumer.topK(asList(ITEM1, ITEM3, ITEM2, ITEM4), 2, Item::compareTo); assertEquals(asList(ITEM4,ITEM3),topItems); for (Item topItem : topItems) { System.out.println(topItem); @@ -89,7 +89,7 @@ public void testFindTopKHeap() throws Exception { } @Test public void testFindTopKHeapDuplicates() throws Exception { - Collection topItems = TopKConsumer.topK(asList(ITEM2, ITEM3, ITEM3, ITEM4), 3); + Collection topItems = TopKConsumer.topK(asList(ITEM2, ITEM3, ITEM3, ITEM4), 3, Item::compareTo); assertEquals(asList(ITEM4,ITEM3,ITEM3),topItems); for (Item topItem : topItems) { System.out.println(topItem); @@ -98,7 +98,7 @@ public void testFindTopKHeapDuplicates() throws Exception { @Test public void testFindTopKHeap2() throws Exception { - List topItems = TopKConsumer.topK(asList(ITEM1, ITEM3, ITEM2, ITEM4), 2); + List topItems = TopKConsumer.topK(asList(ITEM1, ITEM3, ITEM2, ITEM4), 2, Item::compareTo); assertEquals(asList(ITEM4,ITEM3),topItems); for (Item topItem : topItems) { System.out.println(topItem); @@ -111,11 +111,11 @@ public void testFindTopKHeapPerf() throws Exception { List items = createItems(COUNT); List topItems = null; for (int i = 0; i < RUNS/10; i++) { - topItems = TopKConsumer.topK(items, WINDOW_SIZE); + topItems = TopKConsumer.topK(items, WINDOW_SIZE, Item::compareTo); } long time = System.currentTimeMillis(); for (int i = 0; i < RUNS; i++) { - topItems = TopKConsumer.topK(items, WINDOW_SIZE); + topItems = TopKConsumer.topK(items, WINDOW_SIZE, Item::compareTo); } time = System.currentTimeMillis() - time; System.out.println("array based time = " + time+" "+RUNS+" runs with "+COUNT+" items avg "+1.0*time/RUNS); diff --git a/benchmark/src/main/java/org/neo4j/graphalgo/bench/CosineSimilarityBenchmark.java b/benchmark/src/main/java/org/neo4j/graphalgo/bench/CosineSimilarityBenchmark.java new file mode 100644 index 000000000..d1849ca49 --- /dev/null +++ b/benchmark/src/main/java/org/neo4j/graphalgo/bench/CosineSimilarityBenchmark.java @@ -0,0 +1,85 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.bench; + +import org.neo4j.graphalgo.core.utils.Intersections; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@Threads(1) +@Fork(value = 1, jvmArgs = {"-Xms8g", "-Xmx8g", "-XX:+UseG1GC"}) +@Warmup(iterations = 1000) +@Measurement(iterations = 10000, time = 2) +@State(Scope.Benchmark) +@BenchmarkMode(Mode.SingleShotTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class CosineSimilarityBenchmark { + + public static final int SIZE = 10_000; + + private static double[] initial = generate(SIZE,-42); + private static double[][] data = data(100); + + private static double[][] data(int size) { + double[][] result = new double[size][SIZE]; + for (int i=0;i + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.bench; + +import com.carrotsearch.hppc.LongHashSet; +import org.neo4j.graphalgo.core.utils.Intersections; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@Threads(1) +@Fork(value = 1, jvmArgs = {"-Xms8g", "-Xmx8g", "-XX:+UseG1GC"}) +@Warmup(iterations = 1000) +@Measurement(iterations = 10000, time = 2) +@State(Scope.Benchmark) +@BenchmarkMode(Mode.SingleShotTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class SquareDeltasBenchmark { + + // @Param({"", "Person"}) String label; + + public static final int SIZE = 10_000; + + private static double[] initial = generate(SIZE,-42); + private static double[][] data = data(100); + + private static double[][] data(int size) { + double[][] result = new double[size][SIZE]; + for (int i=0;i(italian) // tag::stream[] MATCH (p:Person)-[:LIKES]->(cuisine) -WITH {source:id(p), targets: collect(id(cuisine))} as userData +WITH {item:id(p), categories: collect(id(cuisine))} as userData WITH collect(userData) as data CALL algo.similarity.jaccard.stream(data) -YIELD source1, source2, count1, count2, intersection, similarity -RETURN algo.getNodeById(source1).name AS from, algo.getNodeById(source2).name AS to, intersection, similarity +YIELD item1, item2, count1, count2, intersection, similarity +RETURN algo.getNodeById(item1).name AS from, algo.getNodeById(item2).name AS to, intersection, similarity ORDER BY similarity DESC // end::stream[] // tag::stream-similarity-cutoff[] MATCH (p:Person)-[:LIKES]->(cuisine) -WITH {source:id(p), targets: collect(id(cuisine))} as userData +WITH {item:id(p), categories: collect(id(cuisine))} as userData WITH collect(userData) as data CALL algo.similarity.jaccard.stream(data, {similarityCutoff: 0.0}) -YIELD source1, source2, count1, count2, intersection, similarity -RETURN algo.getNodeById(source1).name AS from, algo.getNodeById(source2).name AS to, intersection, similarity +YIELD item1, item2, count1, count2, intersection, similarity +RETURN algo.getNodeById(item1).name AS from, algo.getNodeById(item2).name AS to, intersection, similarity ORDER BY similarity DESC // end::stream-similarity-cutoff[] // tag::stream-topk[] MATCH (p:Person)-[:LIKES]->(cuisine) -WITH {source:id(p), targets: collect(id(cuisine))} as userData +WITH {item:id(p), categories: collect(id(cuisine))} as userData WITH collect(userData) as data CALL algo.similarity.jaccard.stream(data, {topK: 1, similarityCutoff: 0.0}) -YIELD source1, source2, count1, count2, intersection, similarity -RETURN algo.getNodeById(source1).name AS from, algo.getNodeById(source2).name AS to, similarity +YIELD item1, item2, count1, count2, intersection, similarity +RETURN algo.getNodeById(item1).name AS from, algo.getNodeById(item2).name AS to, similarity ORDER BY from // end::stream-topk[] // tag::write-back[] MATCH (p:Person)-[:LIKES]->(cuisine) -WITH {source:id(p), targets: collect(id(cuisine))} as userData +WITH {item:id(p), categories: collect(id(cuisine))} as userData WITH collect(userData) as data CALL algo.similarity.jaccard(data, {topK: 1, similarityCutoff: 0.1, write:true}) YIELD nodes, similarityPairs, write, writeRelationshipType, writeProperty, min, max, mean, stdDev, p25, p50, p75, p90, p95, p99, p999, p100 diff --git a/doc/asciidoc/similarity-jaccard.adoc b/doc/asciidoc/similarity-jaccard.adoc index 2738d96f9..7c8da3637 100644 --- a/doc/asciidoc/similarity-jaccard.adoc +++ b/doc/asciidoc/similarity-jaccard.adoc @@ -155,7 +155,7 @@ For example, the person most similar to Praveena is Zhen, but the person most si [opts="header",cols="1,1,1,1,4"] |=== | Name | Type | Default | Optional | Description -| data | list | null | no | A list of maps of the following structure: `{source: nodeId, targets: [nodeId, nodeId, nodeId]}` +| data | list | null | no | A list of maps of the following structure: `{item: nodeId, categories: [nodeId, nodeId, nodeId]}` | top | int | 0 | yes | The number of similar pairs to return. If `0` it will return as many as it finds. | topK | int | 0 | yes | The number of similar values to return per node. If `0` will return as many as it finds. | similarityCutoff | int | -1 | yes | The threshold for Jaccard similarity. Values below this will not be returned. @@ -167,8 +167,8 @@ For example, the person most similar to Praveena is Zhen, but the person most si [opts="header",cols="1,1,6"] |=== | Name | Type | Description -| source1 | int | The ID of one node in the similarity pair -| source2 | int | The ID of other node in the similarity pair +| item1 | int | The ID of one node in the similarity pair +| item2 | int | The ID of other node in the similarity pair | count1 | int | The size of the `targets` list of one node | count2 | int | The size of the `targets` list of other node | intersection | int | The number of intersecting values in the two nodes `targets` lists @@ -212,7 +212,7 @@ include::scripts/similarity-jaccard.cypher[tag=query] [opts="header",cols="1,1,1,1,4"] |=== | Name | Type | Default | Optional | Description -| data | list | null | no | A list of maps of the following structure: `{source: nodeId, targets: [nodeId, nodeId, nodeId]}` +| data | list | null | no | A list of maps of the following structure: `{item: nodeId, categories: [nodeId, nodeId, nodeId]}` | top | int | 0 | yes | The number of similar pairs to return. If `0` it will return as many as it finds. | topK | int | 0 | yes | The number of similar values to return per node. If `0` will return as many as it finds. | similarityCutoff | int | -1 | yes | The threshold for Jaccard similarity. Values below this will not be returned. diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/JaccardTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/JaccardTest.java deleted file mode 100644 index 9e00f00ce..000000000 --- a/tests/src/test/java/org/neo4j/graphalgo/algo/JaccardTest.java +++ /dev/null @@ -1,385 +0,0 @@ -/** - * Copyright (c) 2017 "Neo4j, Inc." - * - * This file is part of Neo4j Graph Algorithms . - * - * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.graphalgo.algo; - -import org.junit.*; -import org.neo4j.graphalgo.JaccardProc; -import org.neo4j.graphalgo.TestDatabaseCreator; -import org.neo4j.graphdb.*; -import org.neo4j.internal.kernel.api.exceptions.KernelException; -import org.neo4j.kernel.impl.proc.Procedures; -import org.neo4j.kernel.internal.GraphDatabaseAPI; - -import java.util.*; - -import static java.util.Collections.singletonMap; -import static org.junit.Assert.*; - -public class JaccardTest { - - private static GraphDatabaseAPI db; - private Transaction tx; - - @BeforeClass - public static void beforeClass() throws KernelException { - db = TestDatabaseCreator.createTestDatabase(); - db.getDependencyResolver().resolveDependency(Procedures.class).registerProcedure(JaccardProc.class); - db.execute(buildDatabaseQuery()).close(); - } - - @AfterClass - public static void AfterClass() { - db.shutdown(); - } - - @Before - public void setUp() throws Exception { - tx = db.beginTx(); - } - - @After - public void tearDown() throws Exception { - tx.close(); - } - - private static void buildRandomDB(int size) { - db.execute("MATCH (n) DETACH DELETE n").close(); - db.execute("UNWIND range(1,$size/10) as _ CREATE (:Person) CREATE (:Item) ",singletonMap("size",size)).close(); - String statement = - "MATCH (p:Person) WITH collect(p) as people " + - "MATCH (i:Item) WITH people, collect(i) as items " + - "UNWIND range(1,$size) as _ " + - "WITH people[toInteger(rand()*size(people))] as p, items[toInteger(rand()*size(items))] as i " + - "MERGE (p)-[:LIKES]->(i) RETURN count(*) "; - db.execute(statement,singletonMap("size",size)).close(); - } - private static String buildDatabaseQuery() { - return "CREATE (a:Person {name:'Alice'})\n" + - "CREATE (b:Person {name:'Bob'})\n" + - "CREATE (c:Person {name:'Charlie'})\n" + - "CREATE (d:Person {name:'Dana'})\n" + - "CREATE (i1:Item {name:'p1'})\n" + - "CREATE (i2:Item {name:'p2'})\n" + - "CREATE (i3:Item {name:'p3'})\n" + - - "CREATE" + - " (a)-[:LIKES]->(i1),\n" + - " (a)-[:LIKES]->(i2),\n" + - " (a)-[:LIKES]->(i3),\n" + - " (b)-[:LIKES]->(i1),\n" + - " (b)-[:LIKES]->(i2),\n" + - " (c)-[:LIKES]->(i3)\n"; - // a: 3 - // b: 2 - // c: 1 - // a / b = 2 : 2/3 - // a / c = 1 : 1/3 - // b / c = 0 : 0/3 = 0 - // - } - - - @Test - public void jaccardSingleMultiThreadComparision() { - int size = 3333; - buildRandomDB(size); - String query = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + - "WITH {source:id(p), targets: collect(distinct id(i))} as userData\n" + - "WITH collect(userData) as data\n" + - "call algo.similarity.jaccard.stream(data,{similarityCutoff:-0.1,concurrency:$threads}) " + - "yield source1, source2, count1, count2, intersection, similarity " + - "RETURN source1, source2, count1, count2, intersection, similarity ORDER BY source1,source2"; - Result result1 = db.execute(query, singletonMap("threads", 1)); - Result result2 = db.execute(query, singletonMap("threads", 2)); - Result result4 = db.execute(query, singletonMap("threads", 4)); - Result result8 = db.execute(query, singletonMap("threads", 8)); - int count=0; - while (result1.hasNext()) { - Map row1 = result1.next(); - assertEquals(row1.toString(), row1,result2.next()); - assertEquals(row1.toString(), row1,result4.next()); - assertEquals(row1.toString(), row1,result8.next()); - count++; - } - int people = size/10; - assertEquals((people * people - people)/2,count); - } - - @Test - public void jaccardSingleMultiThreadComparisionTopK() { - int size = 3333; - buildRandomDB(size); - String query = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + - "WITH {source:id(p), targets: collect(distinct id(i))} as userData\n" + - "WITH collect(userData) as data\n" + - "call algo.similarity.jaccard.stream(data,{similarityCutoff:-0.1,concurrency:$threads,topK:1}) " + - "yield source1, source2, count1, count2, intersection, similarity " + - "RETURN source1, source2, count1, count2, intersection, similarity ORDER BY source1,source2"; - Result result1 = db.execute(query, singletonMap("threads", 1)); - Result result2 = db.execute(query, singletonMap("threads", 2)); - Result result4 = db.execute(query, singletonMap("threads", 4)); - Result result8 = db.execute(query, singletonMap("threads", 8)); - int count=0; - while (result1.hasNext()) { - Map row1 = result1.next(); - assertEquals(row1.toString(), row1,result2.next()); - assertEquals(row1.toString(), row1,result4.next()); - assertEquals(row1.toString(), row1,result8.next()); - count++; - } - int people = size/10; - assertEquals(people,count); - } - - @Test - public void topNJaccardStreamTest() { - String query = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + - "WITH {source:id(p), targets: collect(distinct id(i))} as userData\n" + - "WITH collect(userData) as data\n" + - "call algo.similarity.jaccard.stream(data,{top:2}) " + - "yield source1, source2, count1, count2, intersection, similarity " + - "RETURN * ORDER BY source1,source2"; - - - Result results = db.execute(query); - int count = 0; - assertTrue(results.hasNext()); - Map row = results.next(); - assertEquals(0L, row.get("source1")); - assertEquals(1L, row.get("source2")); - assertEquals(3L, row.get("count1")); - assertEquals(2L, row.get("count2")); - assertEquals(2L, row.get("intersection")); - assertEquals(2.0D / 3, row.get("similarity")); - count++; - assertTrue(results.hasNext()); - row = results.next(); - assertEquals(0L, row.get("source1")); - assertEquals(2L, row.get("source2")); - assertEquals(3L, row.get("count1")); - assertEquals(1L, row.get("count2")); - assertEquals(1L, row.get("intersection")); - assertEquals(1.0D / 3, row.get("similarity")); - count++; - assertFalse(results.hasNext()); - assertEquals(2, count); - } - - @Test - public void jaccardStreamTest() { - String query = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + - "WITH {source:id(p), targets: collect(distinct id(i))} as userData\n" + - "WITH collect(userData) as data\n" + - "call algo.similarity.jaccard.stream(data) " + - "yield source1, source2, count1, count2, intersection, similarity " + - "RETURN * ORDER BY source1,source2"; - - - Result results = db.execute(query); - int count = 0; - assertTrue(results.hasNext()); - Map row = results.next(); - assertEquals(0L, row.get("source1")); - assertEquals(1L, row.get("source2")); - assertEquals(3L, row.get("count1")); - assertEquals(2L, row.get("count2")); - assertEquals(2L, row.get("intersection")); - assertEquals(2.0D / 3, row.get("similarity")); - count++; - assertTrue(results.hasNext()); - row = results.next(); - assertEquals(0L, row.get("source1")); - assertEquals(2L, row.get("source2")); - assertEquals(3L, row.get("count1")); - assertEquals(1L, row.get("count2")); - assertEquals(1L, row.get("intersection")); - assertEquals(1.0D / 3, row.get("similarity")); - count++; - row = results.next(); - assertEquals(1L, row.get("source1")); - assertEquals(2L, row.get("source2")); - assertEquals(2L, row.get("count1")); - assertEquals(1L, row.get("count2")); - assertEquals(0L, row.get("intersection")); - assertEquals(0D / 3, row.get("similarity")); - count++; - assertFalse(results.hasNext()); - assertEquals(3, count); - } - - @Test - public void topKJaccardStreamTest() { - String query = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + - "WITH {source:id(p), targets: collect(distinct id(i))} as userData\n" + - "WITH collect(userData) as data\n" + - "call algo.similarity.jaccard.stream(data,{topK:1, concurrency:1}) " + - "yield source1, source2, count1, count2, intersection, similarity " + - "RETURN * ORDER BY source1,source2"; - - System.out.println(db.execute(query).resultAsString()); - - Result results = db.execute(query); - int count = 0; - assertTrue(results.hasNext()); - Map row = results.next(); - assertEquals(0L, row.get("source1")); - assertEquals(1L, row.get("source2")); - assertEquals(3L, row.get("count1")); - assertEquals(2L, row.get("count2")); - assertEquals(2L, row.get("intersection")); - assertEquals(2.0D / 3, row.get("similarity")); - count++; - row = results.next(); - assertEquals(1L, row.get("source1")); - assertEquals(0L, row.get("source2")); - assertEquals(2L, row.get("count1")); - assertEquals(3L, row.get("count2")); - assertEquals(2L, row.get("intersection")); - assertEquals(2.0D / 3, row.get("similarity")); - count++; - assertTrue(results.hasNext()); - row = results.next(); - assertEquals(2L, row.get("source1")); - assertEquals(0L, row.get("source2")); - assertEquals(1L, row.get("count1")); - assertEquals(3L, row.get("count2")); - assertEquals(1L, row.get("intersection")); - assertEquals(1D / 3, row.get("similarity")); - count++; - assertFalse(results.hasNext()); - assertEquals(3, count); - } - - @Test - public void topK4JaccardStreamTest() { - String statement = "MATCH (i:Item {name:'p1'}) MATCH (d:Person {name:'Dana'}) CREATE (d)-[:LIKES]->(i)\n"; - db.execute(statement).close(); - - String query = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + - "WITH {source:id(p), targets: collect(distinct id(i))} as userData\n" + - "WITH collect(userData) as data\n" + - "call algo.similarity.jaccard.stream(data,{topK:4, concurrency:4, similarityCutoff:-0.1}) " + - "yield source1, source2, count1, count2, intersection, similarity " + - "RETURN * ORDER BY source1,source2"; - - System.out.println(db.execute(query).resultAsString()); - - Result results = db.execute(query); - assertSameSource(results, 3, 0L); - assertSameSource(results, 3, 1L); - assertSameSource(results, 3, 2L); - assertSameSource(results, 3, 3L); - assertFalse(results.hasNext()); - } - - @Test - public void topK3JaccardStreamTest() { - String query = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + - "WITH {source:id(p), targets: collect(distinct id(i))} as userData\n" + - "WITH collect(userData) as data\n" + - "call algo.similarity.jaccard.stream(data,{topK:3, concurrency:3}) " + - "yield source1, source2, count1, count2, intersection, similarity " + - "RETURN * ORDER BY source1,source2"; - - System.out.println(db.execute(query).resultAsString()); - - Result results = db.execute(query); - assertSameSource(results, 2, 0L); - assertSameSource(results, 2, 1L); - assertSameSource(results, 2, 2L); - assertFalse(results.hasNext()); - } - - private void assertSameSource(Result results, int count, long source) { - Map row; - long target = 0; - for (int i = 0; i row = results.next(); - - assertEquals((double) row.get("p50"), 0.33, 0.01); - assertEquals((double) row.get("p95"), 0.66, 0.01); - assertEquals((double) row.get("p99"), 0.66, 0.01); - assertEquals((double) row.get("p100"), 0.66, 0.01); - } - - @Test - public void simpleJaccardWriteTest() { - String query = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + - "WITH {source:id(p), targets: collect(distinct id(i))} as userData\n" + - "WITH collect(userData) as data\n" + - "CALL algo.similarity.jaccard(data, {similarityCutoff: 0.1, write: true}) " + - "yield p50, p75, p90, p99, p999, p100, nodes, similarityPairs " + - "RETURN *"; - - - db.execute(query); - - String checkSimilaritiesQuery = "MATCH (a)-[similar:SIMILAR]-(b)" + - "RETURN a.name AS node1, b.name as node2, similar.score AS score " + - "ORDER BY id(a), id(b)"; - - Result result = db.execute(checkSimilaritiesQuery); - - assertTrue(result.hasNext()); - Map row = result.next(); - assertEquals(row.get("node1"), "Alice"); - assertEquals(row.get("node2"), "Bob"); - assertEquals((double) row.get("score"), 0.66, 0.01); - - assertTrue(result.hasNext()); - row = result.next(); - assertEquals(row.get("node1"), "Alice"); - assertEquals(row.get("node2"), "Charlie"); - assertEquals((double) row.get("score"), 0.33, 0.01); - - assertTrue(result.hasNext()); - row = result.next(); - assertEquals(row.get("node1"), "Bob"); - assertEquals(row.get("node2"), "Alice"); - assertEquals((double) row.get("score"), 0.66, 0.01); - - assertTrue(result.hasNext()); - row = result.next(); - assertEquals(row.get("node1"), "Charlie"); - assertEquals(row.get("node2"), "Alice"); - assertEquals((double) row.get("score"), 0.33, 0.01); - - } - -} diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java new file mode 100644 index 000000000..df4e5f0a6 --- /dev/null +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java @@ -0,0 +1,364 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.algo.similarity; + +import org.junit.*; +import org.neo4j.graphalgo.TestDatabaseCreator; +import org.neo4j.graphalgo.similarity.CosineProc; +import org.neo4j.graphdb.Result; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.exceptions.KernelException; +import org.neo4j.kernel.impl.proc.Procedures; +import org.neo4j.kernel.internal.GraphDatabaseAPI; + +import java.util.Map; + +import static java.lang.Math.sqrt; +import static java.util.Collections.singletonMap; +import static org.junit.Assert.*; +import static org.neo4j.helpers.collection.MapUtil.map; + +public class CosineTest { + + private static GraphDatabaseAPI db; + private Transaction tx; + public static final String STATEMENT_STREAM = "MATCH (i:Item) WITH i ORDER BY id(i) MATCH (p:Person) OPTIONAL MATCH (p)-[r:LIKES]->(i)\n" + + "WITH {item:id(p), weights: collect(coalesce(r.stars,0))} as userData\n" + + "WITH collect(userData) as data\n" + + "call algo.similarity.cosine.stream(data,$config) " + + "yield item1, item2, count1, count2, intersection, similarity " + + "RETURN * ORDER BY item1,item2"; + + public static final String STATEMENT = "MATCH (i:Item) WITH i ORDER BY id(i) MATCH (p:Person) OPTIONAL MATCH (p)-[r:LIKES]->(i)\n" + + "WITH {item:id(p), weights: collect(coalesce(r.stars,0))} as userData\n" + + "WITH collect(userData) as data\n" + + + "CALL algo.similarity.cosine(data, $config) " + + "yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " + + "RETURN *"; + + @BeforeClass + public static void beforeClass() throws KernelException { + db = TestDatabaseCreator.createTestDatabase(); + db.getDependencyResolver().resolveDependency(Procedures.class).registerProcedure(CosineProc.class); + db.execute(buildDatabaseQuery()).close(); + } + + @AfterClass + public static void AfterClass() { + db.shutdown(); + } + + @Before + public void setUp() throws Exception { + tx = db.beginTx(); + } + + @After + public void tearDown() throws Exception { + tx.close(); + } + + private static void buildRandomDB(int size) { + db.execute("MATCH (n) DETACH DELETE n").close(); + db.execute("UNWIND range(1,$size/10) as _ CREATE (:Person) CREATE (:Item) ",singletonMap("size",size)).close(); + String statement = + "MATCH (p:Person) WITH collect(p) as people " + + "MATCH (i:Item) WITH people, collect(i) as items " + + "UNWIND range(1,$size) as _ " + + "WITH people[toInteger(rand()*size(people))] as p, items[toInteger(rand()*size(items))] as i " + + "MERGE (p)-[:LIKES]->(i) RETURN count(*) "; + db.execute(statement,singletonMap("size",size)).close(); + } + private static String buildDatabaseQuery() { + return "CREATE (a:Person {name:'Alice'})\n" + + "CREATE (b:Person {name:'Bob'})\n" + + "CREATE (c:Person {name:'Charlie'})\n" + + "CREATE (d:Person {name:'Dana'})\n" + + "CREATE (i1:Item {name:'p1'})\n" + + "CREATE (i2:Item {name:'p2'})\n" + + "CREATE (i3:Item {name:'p3'})\n" + + + "CREATE" + + " (a)-[:LIKES {stars:1}]->(i1),\n" + + " (a)-[:LIKES {stars:2}]->(i2),\n" + + " (a)-[:LIKES {stars:5}]->(i3),\n" + + " (b)-[:LIKES {stars:1}]->(i1),\n" + + " (b)-[:LIKES {stars:3}]->(i2),\n" + + " (c)-[:LIKES {stars:4}]->(i3)\n"; + + /* + for (int i = 0; i < len; i++) { + double weight1 = vector1[i]; + // if (weight1 == 0d) continue; + double weight2 = vector2[i]; + // if (weight2 == 0d) continue; + + dotProduct += weight1 * weight2; + xLength += weight1 * weight1; + yLength += weight2 * weight2; + } + + return dotProduct / Math.sqrt(xLength * yLength); + + */ + // a: 1,2,5 : xL 30 + // b: 1,3,0 : 10 + // c: 0,0,4 : 16 + // d: 0,0,0 : 0 + // a0 - b1: (1+6)/sqrt(30*10) = 0.4 + // a0 - c2: 20 / sqrt(30*16) = 0.91 + // a0 - d3: 0 / sqrt(30*0) = 0 + // b1 - c2: 0 / sqrt(10*16) = 0 + // b1 - d3: 0 + // c2 - d3: 0 + + } + + + @Test + public void cosineSingleMultiThreadComparision() { + int size = 333; + buildRandomDB(size); + Result result1 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 1))); + Result result2 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 2))); + Result result4 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 4))); + Result result8 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 8))); + int count=0; + while (result1.hasNext()) { + Map row1 = result1.next(); + assertEquals(row1.toString(), row1,result2.next()); + assertEquals(row1.toString(), row1,result4.next()); + assertEquals(row1.toString(), row1,result8.next()); + count++; + } + int people = size/10; + assertEquals((people * people - people)/2,count); + } + + @Test + public void cosineSingleMultiThreadComparisionTopK() { + int size = 333; + buildRandomDB(size); + + Result result1 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 1))); + Result result2 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 2))); + Result result4 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 4))); + Result result8 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 8))); + int count=0; + while (result1.hasNext()) { + Map row1 = result1.next(); + assertEquals(row1.toString(), row1,result2.next()); + assertEquals(row1.toString(), row1,result4.next()); + assertEquals(row1.toString(), row1,result8.next()); + count++; + } + int people = size/10; + assertEquals(people,count); + } + + @Test + public void topNcosineStreamTest() { + Result results = db.execute(STATEMENT_STREAM, map("config",map("top",2))); + assert01(results.next()); + assert02(results.next()); + assertFalse(results.hasNext()); + } + + @Test + public void cosineStreamTest() { + Result results = db.execute(STATEMENT_STREAM, map("config",map("concurrency",1))); + assertTrue(results.hasNext()); + assert01(results.next()); + assert02(results.next()); + assert03(results.next()); + assert12(results.next()); + assert13(results.next()); + assert23(results.next()); + assertFalse(results.hasNext()); + } + + @Test + public void topKCosineStreamTest() { + Map params = map("config", map( "concurrency", 1,"topK", 1)); + System.out.println(db.execute(STATEMENT_STREAM, params).resultAsString()); + Result results = db.execute(STATEMENT_STREAM, params); + assertTrue(results.hasNext()); + assert02(results.next()); + assert01(flip(results.next())); + assert02(flip(results.next())); + assert03(flip(results.next())); + assertFalse(results.hasNext()); + } + + private Map flip(Map row) { + return map("similarity", row.get("similarity"),"intersection", row.get("intersection"), + "item1",row.get("item2"),"count1",row.get("count2"), + "item2",row.get("item1"),"count2",row.get("count1")); + } + + private void assertSameSource(Result results, int count, long source) { + Map row; + long target = 0; + for (int i = 0; i params = map("config", map("topK", 4, "concurrency", 4, "similarityCutoff", -0.1)); + + Result results = db.execute(STATEMENT_STREAM,params); + assertSameSource(results, 3, 0L); + assertSameSource(results, 3, 1L); + assertSameSource(results, 3, 2L); + assertSameSource(results, 3, 3L); + assertFalse(results.hasNext()); + } + + @Test + public void topK3cosineStreamTest() { + Map params = map("config", map("concurrency", 3, "topK", 3)); + + System.out.println(db.execute(STATEMENT_STREAM, params).resultAsString()); + + Result results = db.execute(STATEMENT_STREAM, params); + assertSameSource(results, 3, 0L); + assertSameSource(results, 3, 1L); + assertSameSource(results, 3, 2L); + assertSameSource(results, 3, 3L); + assertFalse(results.hasNext()); + } + + @Test + public void simpleCosineTest() { + Map params = map("config", map()); + + Map row = db.execute(STATEMENT,params).next(); + assertEquals((double) row.get("p25"), 0.0, 0.01); + assertEquals((double) row.get("p50"), 0, 0.01); + assertEquals((double) row.get("p75"), 0.40, 0.01); + assertEquals((double) row.get("p90"), 0.40, 0.01); + assertEquals((double) row.get("p95"), 0.91, 0.01); + assertEquals((double) row.get("p99"), 0.91, 0.01); + assertEquals((double) row.get("p100"), 0.91, 0.01); + } + + @Test + public void simpleCosineWriteTest() { + Map params = map("config", map( "write",true, "similarityCutoff", 0.1)); + + db.execute(STATEMENT,params).close(); + + String checkSimilaritiesQuery = "MATCH (a)-[similar:SIMILAR]-(b)" + + "RETURN a.name AS node1, b.name as node2, similar.score AS score " + + "ORDER BY id(a), id(b)"; + + System.out.println(db.execute(checkSimilaritiesQuery).resultAsString()); + Result result = db.execute(checkSimilaritiesQuery); + + assertTrue(result.hasNext()); + Map row = result.next(); + assertEquals(row.get("node1"), "Alice"); + assertEquals(row.get("node2"), "Bob"); + assertEquals((double) row.get("score"), 0.40, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Alice"); + assertEquals(row.get("node2"), "Charlie"); + assertEquals((double) row.get("score"), 0.91, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Bob"); + assertEquals(row.get("node2"), "Alice"); + assertEquals((double) row.get("score"), 0.40, 0.01); + + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Charlie"); + assertEquals(row.get("node2"), "Alice"); + assertEquals((double) row.get("score"), 0.91, 0.01); + + assertFalse(result.hasNext()); + } + + private void assert23(Map row) { + assertEquals(2L, row.get("item1")); + assertEquals(3L, row.get("item2")); + assertEquals(1L, row.get("count1")); + assertEquals(0L, row.get("count2")); + assertEquals(0L, row.get("intersection")); + assertEquals(0d, row.get("similarity")); + } + + private void assert13(Map row) { + assertEquals(1L, row.get("item1")); + assertEquals(3L, row.get("item2")); + assertEquals(2L, row.get("count1")); + assertEquals(0L, row.get("count2")); + assertEquals(0L, row.get("intersection")); + assertEquals(0d, row.get("similarity")); + } + + private void assert12(Map row) { + assertEquals(1L, row.get("item1")); + assertEquals(2L, row.get("item2")); + assertEquals(2L, row.get("count1")); + assertEquals(1L, row.get("count2")); + // assertEquals(0L, row.get("intersection")); + assertEquals(0.0, row.get("similarity")); + } + + private void assert03(Map row) { + assertEquals(0L, row.get("item1")); + assertEquals(3L, row.get("item2")); + assertEquals(3L, row.get("count1")); + assertEquals(0L, row.get("count2")); + assertEquals(0L, row.get("intersection")); + assertEquals(0d, row.get("similarity")); + } + + private void assert02(Map row) { + assertEquals(0L, row.get("item1")); + assertEquals(2L, row.get("item2")); + assertEquals(3L, row.get("count1")); + assertEquals(1L, row.get("count2")); + // assertEquals(1L, row.get("intersection")); + assertEquals(0.91, (double)row.get("similarity"),0.01); + } + + private void assert01(Map row) { + assertEquals(0L, row.get("item1")); + assertEquals(1L, row.get("item2")); + assertEquals(3L, row.get("count1")); + assertEquals(2L, row.get("count2")); + // assertEquals(2L, row.get("intersection")); + assertEquals(0.40, (double)row.get("similarity"),0.01); + } +} diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java new file mode 100644 index 000000000..d3d5b0242 --- /dev/null +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java @@ -0,0 +1,381 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.algo.similarity; + +import org.junit.*; +import org.neo4j.graphalgo.TestDatabaseCreator; +import org.neo4j.graphalgo.similarity.EuclideanProc; +import org.neo4j.graphdb.Result; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.exceptions.KernelException; +import org.neo4j.kernel.impl.proc.Procedures; +import org.neo4j.kernel.internal.GraphDatabaseAPI; + +import java.util.Map; + +import static java.lang.Math.sqrt; +import static java.util.Collections.singletonMap; +import static org.junit.Assert.*; +import static org.neo4j.helpers.collection.MapUtil.map; + +public class EuclideanTest { + + private static GraphDatabaseAPI db; + private Transaction tx; + public static final String STATEMENT_STREAM = "MATCH (i:Item) WITH i ORDER BY id(i) MATCH (p:Person) OPTIONAL MATCH (p)-[r:LIKES]->(i)\n" + + "WITH {item:id(p), weights: collect(coalesce(r.stars,0))} as userData\n" + + "WITH collect(userData) as data\n" + + "call algo.similarity.euclidean.stream(data,$config) " + + "yield item1, item2, count1, count2, intersection, similarity " + + "RETURN * ORDER BY item1,item2"; + + public static final String STATEMENT = "MATCH (i:Item) WITH i ORDER BY id(i) MATCH (p:Person) OPTIONAL MATCH (p)-[r:LIKES]->(i)\n" + + "WITH {item:id(p), weights: collect(coalesce(r.stars,0))} as userData\n" + + "WITH collect(userData) as data\n" + + + "CALL algo.similarity.euclidean(data, $config) " + + "yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " + + "RETURN *"; + + @BeforeClass + public static void beforeClass() throws KernelException { + db = TestDatabaseCreator.createTestDatabase(); + db.getDependencyResolver().resolveDependency(Procedures.class).registerProcedure(EuclideanProc.class); + db.execute(buildDatabaseQuery()).close(); + } + + @AfterClass + public static void AfterClass() { + db.shutdown(); + } + + @Before + public void setUp() throws Exception { + tx = db.beginTx(); + } + + @After + public void tearDown() throws Exception { + tx.close(); + } + + private static void buildRandomDB(int size) { + db.execute("MATCH (n) DETACH DELETE n").close(); + db.execute("UNWIND range(1,$size/10) as _ CREATE (:Person) CREATE (:Item) ",singletonMap("size",size)).close(); + String statement = + "MATCH (p:Person) WITH collect(p) as people " + + "MATCH (i:Item) WITH people, collect(i) as items " + + "UNWIND range(1,$size) as _ " + + "WITH people[toInteger(rand()*size(people))] as p, items[toInteger(rand()*size(items))] as i " + + "MERGE (p)-[:LIKES]->(i) RETURN count(*) "; + db.execute(statement,singletonMap("size",size)).close(); + } + private static String buildDatabaseQuery() { + return "CREATE (a:Person {name:'Alice'})\n" + + "CREATE (b:Person {name:'Bob'})\n" + + "CREATE (c:Person {name:'Charlie'})\n" + + "CREATE (d:Person {name:'Dana'})\n" + + "CREATE (i1:Item {name:'p1'})\n" + + "CREATE (i2:Item {name:'p2'})\n" + + "CREATE (i3:Item {name:'p3'})\n" + + + "CREATE" + + " (a)-[:LIKES {stars:1}]->(i1),\n" + + " (a)-[:LIKES {stars:2}]->(i2),\n" + + " (a)-[:LIKES {stars:5}]->(i3),\n" + + " (b)-[:LIKES {stars:1}]->(i1),\n" + + " (b)-[:LIKES {stars:3}]->(i2),\n" + + " (c)-[:LIKES {stars:4}]->(i3)\n"; + // a: 1,2,5 + // b: 1,3,0 + // c: 0,0,4 + // a - b: sqrt(26) = 5.1 + // a - c: sqrt(6) = 2.5 + // b - c: sqrt(26) = 5.1 + } + + + @Test + public void euclideanSingleMultiThreadComparision() { + int size = 333; + buildRandomDB(size); + Result result1 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 1))); + Result result2 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 2))); + Result result4 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 4))); + Result result8 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 8))); + int count=0; + while (result1.hasNext()) { + Map row1 = result1.next(); + assertEquals(row1.toString(), row1,result2.next()); + assertEquals(row1.toString(), row1,result4.next()); + assertEquals(row1.toString(), row1,result8.next()); + count++; + } + int people = size/10; + assertEquals((people * people - people)/2,count); + } + + @Test + public void euclideanSingleMultiThreadComparisionTopK() { + int size = 333; + buildRandomDB(size); + + Result result1 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 1))); + Result result2 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 2))); + Result result4 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 4))); + Result result8 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 8))); + int count=0; + while (result1.hasNext()) { + Map row1 = result1.next(); + assertEquals(row1.toString(), row1,result2.next()); + assertEquals(row1.toString(), row1,result4.next()); + assertEquals(row1.toString(), row1,result8.next()); + count++; + } + int people = size/10; + assertEquals(people,count); + } + + @Test + public void topNeuclideanStreamTest() { + Result results = db.execute(STATEMENT_STREAM, map("config",map("top",2))); + assert02(results.next()); + assert13(results.next()); + assertFalse(results.hasNext()); + } + + @Test + public void euclideanStreamTest() { + // System.out.println(db.execute("MATCH (i:Item) WITH i ORDER BY id(i) MATCH (p:Person) OPTIONAL MATCH (p)-[r:LIKES]->(i) RETURN p,r,i").resultAsString()); + // a: 1,2,5 + // b: 1,3,0 + // c: 0,0,4 + // d: 0,0,0 + // a0 - b1: sqrt(26) = 5.1 + // a0 - c2: sqrt(6) = 2.5 + // a0 - d3: sqrt(1+4+25) = 5.5 + // b1 - c2: sqrt(26) = 5.1 + // b1 - d3: sqrt(10) = 3.2 + // c2 - d3: sqrt(16) = 4 + // System.out.println(db.execute(query).resultAsString()); + + Result results = db.execute(STATEMENT_STREAM, map("config",map("concurrency",1))); + assertTrue(results.hasNext()); + assert01(results.next()); + assert02(results.next()); + assert03(results.next()); + assert12(results.next()); + assert13(results.next()); + assert23(results.next()); + assertFalse(results.hasNext()); + } + + @Test + public void topKEuclideanStreamTest() { + Map params = map("config", map( "concurrency", 1,"topK", 1)); + + Result results = db.execute(STATEMENT_STREAM, params); + assertTrue(results.hasNext()); + assert02(results.next()); + assert13(results.next()); + assert02(flip(results.next())); + assert13(flip(results.next())); + assertFalse(results.hasNext()); + } + + private Map flip(Map row) { + return map("similarity", row.get("similarity"),"intersection", row.get("intersection"), + "item1",row.get("item2"),"count1",row.get("count2"), + "item2",row.get("item1"),"count2",row.get("count1")); + } + + private void assertSameSource(Result results, int count, long source) { + Map row; + long target = 0; + for (int i = 0; i params = map("config", map("topK", 4, "concurrency", 4, "similarityCutoff", -0.1)); + System.out.println(db.execute(STATEMENT_STREAM,params).resultAsString()); + + Result results = db.execute(STATEMENT_STREAM,params); + assertSameSource(results, 3, 0L); + assertSameSource(results, 3, 1L); + assertSameSource(results, 3, 2L); + assertSameSource(results, 3, 3L); + assertFalse(results.hasNext()); + } + + @Test + public void topK3euclideanStreamTest() { + // a0 - b1: sqrt(26) = 5.1 + // a0 - c2: sqrt(6) = 2.5 + // a0 - d3: sqrt(1+4+25) = 5.5 + // b1 - c2: sqrt(26) = 5.1 + // b1 - d3: sqrt(10) = 3.2 + // c2 - d3: sqrt(16) = 4 + Map params = map("config", map("concurrency", 3, "topK", 3)); + + System.out.println(db.execute(STATEMENT_STREAM, params).resultAsString()); + + Result results = db.execute(STATEMENT_STREAM, params); + assertSameSource(results, 3, 0L); + assertSameSource(results, 3, 1L); + assertSameSource(results, 3, 2L); + assertSameSource(results, 3, 3L); + assertFalse(results.hasNext()); + } + + @Test + public void simpleEuclideanTest() { + Map params = map("config", map()); + + Map row = db.execute(STATEMENT,params).next(); + assertEquals((double) row.get("p25"), 3.16, 0.01); + assertEquals((double) row.get("p50"), 4.00, 0.01); + assertEquals((double) row.get("p75"), 5.10, 0.01); + assertEquals((double) row.get("p95"), 5.48, 0.01); + assertEquals((double) row.get("p99"), 5.48, 0.01); + assertEquals((double) row.get("p100"), 5.48, 0.01); + } + + @Test + public void simpleEuclideanWriteTest() { + Map params = map("config", map( "write",true, "similarityCutoff", 4.0)); + + db.execute(STATEMENT,params).close(); + + String checkSimilaritiesQuery = "MATCH (a)-[similar:SIMILAR]-(b)" + + "RETURN a.name AS node1, b.name as node2, similar.score AS score " + + "ORDER BY id(a), id(b)"; + + System.out.println(db.execute(checkSimilaritiesQuery).resultAsString()); + Result result = db.execute(checkSimilaritiesQuery); + + // a0 - b1: sqrt(26) = 5.1 + // a0 - c2: sqrt(6) = 2.5 + // a0 - d3: sqrt(1+4+25) = 5.5 + // b1 - c2: sqrt(26) = 5.1 + // b1 - d3: sqrt(10) = 3.2 + // c2 - d3: sqrt(16) = 4 + + assertTrue(result.hasNext()); + Map row = result.next(); + assertEquals(row.get("node1"), "Alice"); + assertEquals(row.get("node2"), "Charlie"); + assertEquals((double) row.get("score"), 2.45, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Bob"); + assertEquals(row.get("node2"), "Dana"); + assertEquals((double) row.get("score"), 3.16, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Charlie"); + assertEquals(row.get("node2"), "Alice"); + assertEquals((double) row.get("score"), 2.45, 0.01); + + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Charlie"); + assertEquals(row.get("node2"), "Dana"); + assertEquals((double) row.get("score"), 4.0, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Dana"); + assertEquals(row.get("node2"), "Bob"); + assertEquals((double) row.get("score"), 3.16, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Dana"); + assertEquals(row.get("node2"), "Charlie"); + assertEquals((double) row.get("score"), 4.0, 0.01); + + assertFalse(result.hasNext()); + } + + private void assert23(Map row) { + assertEquals(2L, row.get("item1")); + assertEquals(3L, row.get("item2")); + assertEquals(1L, row.get("count1")); + assertEquals(0L, row.get("count2")); + assertEquals(0L, row.get("intersection")); + assertEquals(sqrt(16), row.get("similarity")); + } + + private void assert13(Map row) { + assertEquals(1L, row.get("item1")); + assertEquals(3L, row.get("item2")); + assertEquals(2L, row.get("count1")); + assertEquals(0L, row.get("count2")); + assertEquals(0L, row.get("intersection")); + assertEquals(sqrt(10), row.get("similarity")); + } + + private void assert12(Map row) { + assertEquals(1L, row.get("item1")); + assertEquals(2L, row.get("item2")); + assertEquals(2L, row.get("count1")); + assertEquals(1L, row.get("count2")); + // assertEquals(0L, row.get("intersection")); + assertEquals(sqrt(5*5+1), row.get("similarity")); + } + + private void assert03(Map row) { + assertEquals(0L, row.get("item1")); + assertEquals(3L, row.get("item2")); + assertEquals(3L, row.get("count1")); + assertEquals(0L, row.get("count2")); + assertEquals(0L, row.get("intersection")); + assertEquals(sqrt(5*5+2*2+1), row.get("similarity")); + } + + private void assert02(Map row) { + assertEquals(0L, row.get("item1")); + assertEquals(2L, row.get("item2")); + assertEquals(3L, row.get("count1")); + assertEquals(1L, row.get("count2")); + // assertEquals(1L, row.get("intersection")); + assertEquals(sqrt(6), row.get("similarity")); + } + + private void assert01(Map row) { + assertEquals(0L, row.get("item1")); + assertEquals(1L, row.get("item2")); + assertEquals(3L, row.get("count1")); + assertEquals(2L, row.get("count2")); + // assertEquals(2L, row.get("intersection")); + assertEquals(sqrt(5*5+1), row.get("similarity")); + } +} diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java new file mode 100644 index 000000000..13700e7de --- /dev/null +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java @@ -0,0 +1,313 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.algo.similarity; + +import org.junit.*; +import org.neo4j.graphalgo.TestDatabaseCreator; +import org.neo4j.graphalgo.similarity.JaccardProc; +import org.neo4j.graphdb.Result; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.exceptions.KernelException; +import org.neo4j.kernel.impl.proc.Procedures; +import org.neo4j.kernel.internal.GraphDatabaseAPI; + +import java.util.Map; + +import static java.lang.Math.sqrt; +import static java.util.Collections.singletonMap; +import static org.junit.Assert.*; +import static org.neo4j.helpers.collection.MapUtil.map; + +public class JaccardTest { + + private static GraphDatabaseAPI db; + private Transaction tx; + public static final String STATEMENT_STREAM = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + + "WITH {item:id(p), categories: collect(distinct id(i))} as userData\n" + + "WITH collect(userData) as data\n" + + "call algo.similarity.jaccard.stream(data,$config) " + + "yield item1, item2, count1, count2, intersection, similarity " + + "RETURN * ORDER BY item1,item2"; + + public static final String STATEMENT = "MATCH (p:Person)-[:LIKES]->(i:Item) \n" + + "WITH {item:id(p), categories: collect(distinct id(i))} as userData\n" + + "WITH collect(userData) as data\n" + + "CALL algo.similarity.jaccard(data, $config) " + + "yield p25, p50, p75, p90, p95, p99, p999, p100, nodes, similarityPairs " + + "RETURN *"; + + @BeforeClass + public static void beforeClass() throws KernelException { + db = TestDatabaseCreator.createTestDatabase(); + db.getDependencyResolver().resolveDependency(Procedures.class).registerProcedure(JaccardProc.class); + db.execute(buildDatabaseQuery()).close(); + } + + @AfterClass + public static void AfterClass() { + db.shutdown(); + } + + @Before + public void setUp() throws Exception { + tx = db.beginTx(); + } + + @After + public void tearDown() throws Exception { + tx.close(); + } + + private static void buildRandomDB(int size) { + db.execute("MATCH (n) DETACH DELETE n").close(); + db.execute("UNWIND range(1,$size/10) as _ CREATE (:Person) CREATE (:Item) ",singletonMap("size",size)).close(); + String statement = + "MATCH (p:Person) WITH collect(p) as people " + + "MATCH (i:Item) WITH people, collect(i) as items " + + "UNWIND range(1,$size) as _ " + + "WITH people[toInteger(rand()*size(people))] as p, items[toInteger(rand()*size(items))] as i " + + "MERGE (p)-[:LIKES]->(i) RETURN count(*) "; + db.execute(statement,singletonMap("size",size)).close(); + } + private static String buildDatabaseQuery() { + return "CREATE (a:Person {name:'Alice'})\n" + + "CREATE (b:Person {name:'Bob'})\n" + + "CREATE (c:Person {name:'Charlie'})\n" + + "CREATE (d:Person {name:'Dana'})\n" + + "CREATE (i1:Item {name:'p1'})\n" + + "CREATE (i2:Item {name:'p2'})\n" + + "CREATE (i3:Item {name:'p3'})\n" + + + "CREATE" + + " (a)-[:LIKES]->(i1),\n" + + " (a)-[:LIKES]->(i2),\n" + + " (a)-[:LIKES]->(i3),\n" + + " (b)-[:LIKES]->(i1),\n" + + " (b)-[:LIKES]->(i2),\n" + + " (c)-[:LIKES]->(i3)\n"; + // a: 3 + // b: 2 + // c: 1 + // a / b = 2 : 2/3 + // a / c = 1 : 1/3 + // b / c = 0 : 0/3 = 0 + } + + + @Test + public void jaccardSingleMultiThreadComparision() { + int size = 333; + buildRandomDB(size); + Result result1 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 1))); + Result result2 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 2))); + Result result4 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 4))); + Result result8 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"concurrency", 8))); + int count=0; + while (result1.hasNext()) { + Map row1 = result1.next(); + assertEquals(row1.toString(), row1,result2.next()); + assertEquals(row1.toString(), row1,result4.next()); + assertEquals(row1.toString(), row1,result8.next()); + count++; + } + int people = size/10; + assertEquals((people * people - people)/2,count); + } + + @Test + public void jaccardSingleMultiThreadComparisionTopK() { + int size = 333; + buildRandomDB(size); + + Result result1 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 1))); + Result result2 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 2))); + Result result4 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 4))); + Result result8 = db.execute(STATEMENT_STREAM, map("config", map("similarityCutoff",-0.1,"topK",1,"concurrency", 8))); + int count=0; + while (result1.hasNext()) { + Map row1 = result1.next(); + assertEquals(row1.toString(), row1,result2.next()); + assertEquals(row1.toString(), row1,result4.next()); + assertEquals(row1.toString(), row1,result8.next()); + count++; + } + int people = size/10; + assertEquals(people,count); + } + + @Test + public void topNjaccardStreamTest() { + Result results = db.execute(STATEMENT_STREAM, map("config",map("top",2))); + assert01(results.next()); + assert02(results.next()); + assertFalse(results.hasNext()); + } + + @Test + public void jaccardStreamTest() { + Result results = db.execute(STATEMENT_STREAM, map("config",map("concurrency",1))); + assertTrue(results.hasNext()); + assert01(results.next()); + assert02(results.next()); + assert12(results.next()); + assertFalse(results.hasNext()); + } + + @Test + public void topKJaccardStreamTest() { + Map params = map("config", map( "concurrency", 1,"topK", 1)); + System.out.println(db.execute(STATEMENT_STREAM, params).resultAsString()); + + Result results = db.execute(STATEMENT_STREAM, params); + assertTrue(results.hasNext()); + assert01(results.next()); + assert01(flip(results.next())); + assert02(flip(results.next())); + assertFalse(results.hasNext()); + } + + private Map flip(Map row) { + return map("similarity", row.get("similarity"),"intersection", row.get("intersection"), + "item1",row.get("item2"),"count1",row.get("count2"), + "item2",row.get("item1"),"count2",row.get("count1")); + } + + private void assertSameSource(Result results, int count, long source) { + Map row; + long target = 0; + for (int i = 0; i params = map("config", map("topK", 4, "concurrency", 4, "similarityCutoff", -0.1)); + System.out.println(db.execute(STATEMENT_STREAM,params).resultAsString()); + + Result results = db.execute(STATEMENT_STREAM,params); + assertSameSource(results, 2, 0L); + assertSameSource(results, 2, 1L); + assertSameSource(results, 2, 2L); + assertFalse(results.hasNext()); + } + + @Test + public void topK3jaccardStreamTest() { + Map params = map("config", map("concurrency", 3, "topK", 3)); + + System.out.println(db.execute(STATEMENT_STREAM, params).resultAsString()); + + Result results = db.execute(STATEMENT_STREAM, params); + assertSameSource(results, 2, 0L); + assertSameSource(results, 2, 1L); + assertSameSource(results, 2, 2L); + assertFalse(results.hasNext()); + } + + @Test + public void simpleJaccardTest() { + Map params = map("config", map("similarityCutoff", 0.0)); + + Map row = db.execute(STATEMENT,params).next(); + assertEquals((double) row.get("p25"), 0.33, 0.01); + assertEquals((double) row.get("p50"), 0.33, 0.01); + assertEquals((double) row.get("p75"), 0.66, 0.01); + assertEquals((double) row.get("p95"), 0.66, 0.01); + assertEquals((double) row.get("p99"), 0.66, 0.01); + assertEquals((double) row.get("p100"), 0.66, 0.01); + } + + @Test + public void simpleJaccardWriteTest() { + Map params = map("config", map( "write",true, "similarityCutoff", 0.1)); + + db.execute(STATEMENT,params).close(); + + String checkSimilaritiesQuery = "MATCH (a)-[similar:SIMILAR]-(b)" + + "RETURN a.name AS node1, b.name as node2, similar.score AS score " + + "ORDER BY id(a), id(b)"; + + System.out.println(db.execute(checkSimilaritiesQuery).resultAsString()); + Result result = db.execute(checkSimilaritiesQuery); + + assertTrue(result.hasNext()); + Map row = result.next(); + assertEquals(row.get("node1"), "Alice"); + assertEquals(row.get("node2"), "Bob"); + assertEquals((double) row.get("score"), 0.66, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Alice"); + assertEquals(row.get("node2"), "Charlie"); + assertEquals((double) row.get("score"), 0.33, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Bob"); + assertEquals(row.get("node2"), "Alice"); + assertEquals((double) row.get("score"), 0.66, 0.01); + + assertTrue(result.hasNext()); + row = result.next(); + assertEquals(row.get("node1"), "Charlie"); + assertEquals(row.get("node2"), "Alice"); + assertEquals((double) row.get("score"), 0.33, 0.01); + + assertFalse(result.hasNext()); + } + + private void assert12(Map row) { + assertEquals(1L, row.get("item1")); + assertEquals(2L, row.get("item2")); + assertEquals(2L, row.get("count1")); + assertEquals(1L, row.get("count2")); + // assertEquals(0L, row.get("intersection")); + assertEquals(0d, row.get("similarity")); + } + + // a / b = 2 : 2/3 + // a / c = 1 : 1/3 + // b / c = 0 : 0/3 = 0 + + private void assert02(Map row) { + assertEquals(0L, row.get("item1")); + assertEquals(2L, row.get("item2")); + assertEquals(3L, row.get("count1")); + assertEquals(1L, row.get("count2")); + // assertEquals(1L, row.get("intersection")); + assertEquals(1d/3d, row.get("similarity")); + } + + private void assert01(Map row) { + assertEquals(0L, row.get("item1")); + assertEquals(1L, row.get("item2")); + assertEquals(3L, row.get("count1")); + assertEquals(2L, row.get("count2")); + // assertEquals(2L, row.get("intersection")); + assertEquals(2d/3d, row.get("similarity")); + } +} diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/SimilarityTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/SimilaritiesTest.java similarity index 98% rename from tests/src/test/java/org/neo4j/graphalgo/algo/SimilarityTest.java rename to tests/src/test/java/org/neo4j/graphalgo/algo/similarity/SimilaritiesTest.java index 0e3924cb4..dcbee4158 100644 --- a/tests/src/test/java/org/neo4j/graphalgo/algo/SimilarityTest.java +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/SimilaritiesTest.java @@ -16,12 +16,12 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package org.neo4j.graphalgo.algo; +package org.neo4j.graphalgo.algo.similarity; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; -import org.neo4j.graphalgo.Similarity; +import org.neo4j.graphalgo.similarity.Similarities; import org.neo4j.graphalgo.TestDatabaseCreator; import org.neo4j.graphdb.Result; import org.neo4j.graphdb.Transaction; @@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals; -public class SimilarityTest { +public class SimilaritiesTest { private static final String SETUP = "create (java:Skill{name:'Java'})\n" + "create (neo4j:Skill{name:'Neo4j'})\n" + "create (nodejs:Skill{name:'NodeJS'})\n" + @@ -62,7 +62,7 @@ public static void setUp() throws Exception { db.getDependencyResolver() .resolveDependency(Procedures.class) - .registerFunction(Similarity.class); + .registerFunction(Similarities.class); db.execute(SETUP).close(); } From 57aea9e863b4f7dd1b13a8456da3f4ff0e074e0a Mon Sep 17 00:00:00 2001 From: Michael Hunger Date: Tue, 18 Sep 2018 13:23:35 +0200 Subject: [PATCH 2/2] Added DirectIdMap --- .../org/neo4j/graphalgo/api/GraphFactory.java | 10 +- .../org/neo4j/graphalgo/core/DirectIdMap.java | 85 ++++++++ .../java/org/neo4j/graphalgo/core/IdMap.java | 202 +---------------- .../neo4j/graphalgo/core/MappingIdMap.java | 205 ++++++++++++++++++ .../neo4j/graphalgo/core/NodeImporter.java | 10 +- .../heavyweight/HeavyCypherGraphFactory.java | 12 +- .../core/heavyweight/HeavyGraph.java | 8 +- .../core/heavyweight/HeavyGraphFactory.java | 4 +- .../heavyweight/RelationshipImporter.java | 7 +- .../core/heavyweight/VisitRelationship.java | 16 +- .../graphalgo/core/utils/ParallelUtil.java | 6 +- .../neo4j/graphalgo/core/DirectIdMapTest.java | 116 ++++++++++ .../org/neo4j/graphalgo/core/IdMapTest.java | 13 +- 13 files changed, 446 insertions(+), 248 deletions(-) create mode 100644 core/src/main/java/org/neo4j/graphalgo/core/DirectIdMap.java create mode 100644 core/src/main/java/org/neo4j/graphalgo/core/MappingIdMap.java create mode 100644 tests/src/test/java/org/neo4j/graphalgo/core/DirectIdMapTest.java diff --git a/core/src/main/java/org/neo4j/graphalgo/api/GraphFactory.java b/core/src/main/java/org/neo4j/graphalgo/api/GraphFactory.java index 6c231cdb6..a2394b374 100644 --- a/core/src/main/java/org/neo4j/graphalgo/api/GraphFactory.java +++ b/core/src/main/java/org/neo4j/graphalgo/api/GraphFactory.java @@ -18,13 +18,7 @@ */ package org.neo4j.graphalgo.api; -import org.neo4j.graphalgo.core.GraphDimensions; -import org.neo4j.graphalgo.core.HugeNullWeightMap; -import org.neo4j.graphalgo.core.HugeWeightMap; -import org.neo4j.graphalgo.core.IdMap; -import org.neo4j.graphalgo.core.NodeImporter; -import org.neo4j.graphalgo.core.NullWeightMap; -import org.neo4j.graphalgo.core.WeightMap; +import org.neo4j.graphalgo.core.*; import org.neo4j.graphalgo.core.huge.HugeIdMap; import org.neo4j.graphalgo.core.huge.HugeNodeImporter; import org.neo4j.graphalgo.core.utils.ApproximatedImportProgress; @@ -88,7 +82,7 @@ protected ImportProgress importProgress( ); } - protected IdMap loadIdMap() { + protected MappingIdMap loadIdMap() { final NodeImporter nodeImporter = new NodeImporter( api, setup.tracker, diff --git a/core/src/main/java/org/neo4j/graphalgo/core/DirectIdMap.java b/core/src/main/java/org/neo4j/graphalgo/core/DirectIdMap.java new file mode 100644 index 000000000..e91cc16ce --- /dev/null +++ b/core/src/main/java/org/neo4j/graphalgo/core/DirectIdMap.java @@ -0,0 +1,85 @@ +package org.neo4j.graphalgo.core; + +import org.neo4j.collection.primitive.PrimitiveIntIterable; +import org.neo4j.collection.primitive.PrimitiveIntIterator; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.function.IntPredicate; + +public class DirectIdMap implements IdMap { + private final int size; + + public DirectIdMap(int size) { + this.size = size; + } + + @Override + public Collection batchIterables(int batchSize) { + if (batchSize <= 0) throw new IllegalArgumentException("Invalid batch size: "+batchSize); + List result = new ArrayList<>(size / batchSize + 1); + for (int start = 0; start < size; start +=batchSize) { + result.add(new IntIterable(start,Math.min(start+batchSize,size))); + } + return result; + } + + @Override + public int toMappedNodeId(long nodeId) { + return Math.toIntExact(nodeId); + } + + @Override + public long toOriginalNodeId(int nodeId) { + return nodeId; + } + + @Override + public boolean contains(long nodeId) { + return nodeId >= 0 && nodeId < size; + } + + @Override + public long nodeCount() { + return size; + } + + @Override + public void forEachNode(IntPredicate consumer) { + for (int node=0;node < size; node++) { + if (!consumer.test(node)) return; + } + } + + @Override + public PrimitiveIntIterator nodeIterator() { + return new IntIterable(0,size).iterator(); + } + + private static class IntIterable implements PrimitiveIntIterable { + private final int start; + private final int end; + + public IntIterable(int start,int end) { + this.start = start; + this.end = end; + } + + @Override + public PrimitiveIntIterator iterator() { + return new PrimitiveIntIterator() { + int curr = start; + @Override + public boolean hasNext() { + return curr < end; + } + + @Override + public int next() { + return curr++; + } + }; + } + } +} diff --git a/core/src/main/java/org/neo4j/graphalgo/core/IdMap.java b/core/src/main/java/org/neo4j/graphalgo/core/IdMap.java index e10411ab6..b0347c11a 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/IdMap.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/IdMap.java @@ -1,208 +1,8 @@ -/** - * Copyright (c) 2017 "Neo4j, Inc." - * - * This file is part of Neo4j Graph Algorithms . - * - * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ package org.neo4j.graphalgo.core; -import com.carrotsearch.hppc.LongIntHashMap; -import com.carrotsearch.hppc.LongIntMap; -import com.carrotsearch.hppc.cursors.LongIntCursor; -import org.neo4j.collection.primitive.PrimitiveIntIterable; -import org.neo4j.collection.primitive.PrimitiveIntIterator; import org.neo4j.graphalgo.api.BatchNodeIterable; import org.neo4j.graphalgo.api.IdMapping; import org.neo4j.graphalgo.api.NodeIterator; -import org.neo4j.graphalgo.core.utils.ParallelUtil; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.function.IntPredicate; - -/** - * This is basically a long to int mapper. It sorts the id's in ascending order so its - * guaranteed that there is no ID greater then nextGraphId / capacity - */ -public final class IdMap implements IdMapping, NodeIterator, BatchNodeIterable { - - private final IdIterator iter; - private int nextGraphId; - private long[] graphIds; - private LongIntMap nodeToGraphIds; - - /** - * initialize the map with maximum node capacity - */ - public IdMap(final int capacity) { - nodeToGraphIds = new LongIntHashMap((int) Math.ceil(capacity / 0.99), 0.99); - iter = new IdIterator(); - } - - /** - * CTor used by deserializing logic - */ - public IdMap( - long[] graphIds, - LongIntMap nodeToGraphIds) { - this.nextGraphId = graphIds.length; - this.graphIds = graphIds; - this.nodeToGraphIds = nodeToGraphIds; - iter = new IdIterator(); - } - - public PrimitiveIntIterator iterator() { - return iter.reset(nextGraphId); - } - - public int mapOrGet(long longValue) { - int intValue = nodeToGraphIds.getOrDefault(longValue, -1); - if (intValue == -1) { - intValue = nextGraphId++; - nodeToGraphIds.put(longValue, intValue); - } - return intValue; - } - - public void add(long longValue) { - int intValue = nextGraphId++; - nodeToGraphIds.put(longValue, intValue); - } - - public int get(long longValue) { - return nodeToGraphIds.getOrDefault(longValue, -1); - } - - public void buildMappedIds() { - graphIds = new long[size()]; - for (final LongIntCursor cursor : nodeToGraphIds) { - graphIds[cursor.value] = cursor.key; - } - } - - public int size() { - return nextGraphId; - } - - public long[] mappedIds() { - return graphIds; - } - - public LongIntMap nodeToGraphIds() { - return nodeToGraphIds; - } - - public void forEach(IntPredicate consumer) { - int limit = this.nextGraphId; - for (int i = 0; i < limit; i++) { - if (!consumer.test(i)) { - return; - } - } - } - - @Override - public int toMappedNodeId(long nodeId) { - return mapOrGet(nodeId); - } - - @Override - public long toOriginalNodeId(int nodeId) { - return graphIds[nodeId]; - } - - @Override - public boolean contains(final long nodeId) { - return nodeToGraphIds.containsKey(nodeId); - } - - @Override - public long nodeCount() { - return graphIds.length; - } - - @Override - public void forEachNode(IntPredicate consumer) { - final int count = graphIds.length; - for (int i = 0; i < count; i++) { - if (!consumer.test(i)) { - return; - } - } - } - - @Override - public PrimitiveIntIterator nodeIterator() { - return new IdIterator().reset(graphIds.length); - } - - @Override - public Collection batchIterables(int batchSize) { - int nodeCount = graphIds.length; - int numberOfBatches = ParallelUtil.threadSize(batchSize, nodeCount); - if (numberOfBatches == 1) { - return Collections.singleton(this::nodeIterator); - } - PrimitiveIntIterable[] iterators = new PrimitiveIntIterable[numberOfBatches]; - Arrays.setAll(iterators, i -> { - int start = i * batchSize; - int length = Math.min(batchSize, nodeCount - start); - return new IdIterable(start, length); - }); - return Arrays.asList(iterators); - } - - public static final class IdIterable implements PrimitiveIntIterable { - private final int start; - private final int length; - - public IdIterable(int start, int length) { - this.start = start; - this.length = length; - } - - @Override - public PrimitiveIntIterator iterator() { - return new IdIterator().reset(start, length); - } - } - - private static final class IdIterator implements PrimitiveIntIterator { - - private int current; - private int limit; // exclusive upper bound - - private PrimitiveIntIterator reset(int length) { - return reset(0, length); - } - - private PrimitiveIntIterator reset(int start, int length) { - current = start; - this.limit = start + length; - return this; - } - - @Override - public boolean hasNext() { - return current < limit; - } - - @Override - public int next() { - return current++; - } - } +public interface IdMap extends IdMapping, NodeIterator, BatchNodeIterable { } diff --git a/core/src/main/java/org/neo4j/graphalgo/core/MappingIdMap.java b/core/src/main/java/org/neo4j/graphalgo/core/MappingIdMap.java new file mode 100644 index 000000000..a5e103036 --- /dev/null +++ b/core/src/main/java/org/neo4j/graphalgo/core/MappingIdMap.java @@ -0,0 +1,205 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.core; + +import com.carrotsearch.hppc.LongIntHashMap; +import com.carrotsearch.hppc.LongIntMap; +import com.carrotsearch.hppc.cursors.LongIntCursor; +import org.neo4j.collection.primitive.PrimitiveIntIterable; +import org.neo4j.collection.primitive.PrimitiveIntIterator; +import org.neo4j.graphalgo.core.utils.ParallelUtil; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.function.IntPredicate; + +/** + * This is basically a long to int mapper. It sorts the id's in ascending order so its + * guaranteed that there is no ID greater then nextGraphId / capacity + */ +public final class MappingIdMap implements IdMap { + + private final IdIterator iter; + private int nextGraphId; + private long[] graphIds; + private LongIntMap nodeToGraphIds; + + /** + * initialize the map with maximum node capacity + */ + public MappingIdMap(final int capacity) { + nodeToGraphIds = new LongIntHashMap((int) Math.ceil(capacity / 0.99), 0.99); + iter = new IdIterator(); + } + + /** + * CTor used by deserializing logic + */ + public MappingIdMap( + long[] graphIds, + LongIntMap nodeToGraphIds) { + this.nextGraphId = graphIds.length; + this.graphIds = graphIds; + this.nodeToGraphIds = nodeToGraphIds; + iter = new IdIterator(); + } + + public PrimitiveIntIterator iterator() { + return iter.reset(nextGraphId); + } + + public int mapOrGet(long longValue) { + int intValue = nodeToGraphIds.getOrDefault(longValue, -1); + if (intValue == -1) { + intValue = nextGraphId++; + nodeToGraphIds.put(longValue, intValue); + } + return intValue; + } + + public void add(long longValue) { + int intValue = nextGraphId++; + nodeToGraphIds.put(longValue, intValue); + } + + public void buildMappedIds() { + graphIds = new long[size()]; + for (final LongIntCursor cursor : nodeToGraphIds) { + graphIds[cursor.value] = cursor.key; + } + } + + @Override + public long nodeCount() { + return graphIds.length; + } + + public int size() { + return nextGraphId; + } + + public long[] mappedIds() { + return graphIds; + } + + public LongIntMap nodeToGraphIds() { + return nodeToGraphIds; + } + + @Override + public int toMappedNodeId(long nodeId) { + return mapOrGet(nodeId); + } + + public int get(long longValue) { + return nodeToGraphIds.getOrDefault(longValue, -1); + } + + @Override + public long toOriginalNodeId(int nodeId) { + return graphIds[nodeId]; + } + + @Override + public boolean contains(final long nodeId) { + return nodeToGraphIds.containsKey(nodeId); + } + + @Override + public void forEachNode(IntPredicate consumer) { + final int count = graphIds.length; + for (int i = 0; i < count; i++) { + if (!consumer.test(i)) { + return; + } + } + } + + public void forEach(IntPredicate consumer) { + int limit = this.nextGraphId; + for (int i = 0; i < limit; i++) { + if (!consumer.test(i)) { + return; + } + } + } + + @Override + public PrimitiveIntIterator nodeIterator() { + return new IdIterator().reset(graphIds.length); + } + + @Override + public Collection batchIterables(int batchSize) { + int nodeCount = graphIds.length; + int numberOfBatches = ParallelUtil.threadSize(batchSize, nodeCount); + if (numberOfBatches == 1) { + return Collections.singleton(this::nodeIterator); + } + PrimitiveIntIterable[] iterators = new PrimitiveIntIterable[numberOfBatches]; + Arrays.setAll(iterators, i -> { + int start = i * batchSize; + int length = Math.min(batchSize, nodeCount - start); + return new IdIterable(start, length); + }); + return Arrays.asList(iterators); + } + + public static final class IdIterable implements PrimitiveIntIterable { + private final int start; + private final int length; + + public IdIterable(int start, int length) { + this.start = start; + this.length = length; + } + + @Override + public PrimitiveIntIterator iterator() { + return new IdIterator().reset(start, length); + } + } + + private static final class IdIterator implements PrimitiveIntIterator { + + private int current; + private int limit; // exclusive upper bound + + private PrimitiveIntIterator reset(int length) { + return reset(0, length); + } + + private PrimitiveIntIterator reset(int start, int length) { + current = start; + this.limit = start + length; + return this; + } + + @Override + public boolean hasNext() { + return current < limit; + } + + @Override + public int next() { + return current++; + } + } +} diff --git a/core/src/main/java/org/neo4j/graphalgo/core/NodeImporter.java b/core/src/main/java/org/neo4j/graphalgo/core/NodeImporter.java index 8e9c8538b..8a5ac9f59 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/NodeImporter.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/NodeImporter.java @@ -22,7 +22,7 @@ import org.neo4j.graphalgo.core.utils.paged.AllocationTracker; import org.neo4j.kernel.internal.GraphDatabaseAPI; -public final class NodeImporter extends BaseNodeImporter { +public final class NodeImporter extends BaseNodeImporter { private final AllocationTracker tracker; @@ -37,17 +37,17 @@ public NodeImporter( } @Override - protected IdMap newNodeMap(final long nodeCount) { - return new IdMap((int) nodeCount); + protected MappingIdMap newNodeMap(final long nodeCount) { + return new MappingIdMap((int) nodeCount); } @Override - protected void addNodeId(final IdMap map, final long nodeId) { + protected void addNodeId(final MappingIdMap map, final long nodeId) { map.add(nodeId); } @Override - protected void finish(final IdMap map) { + protected void finish(final MappingIdMap map) { map.buildMappedIds(); } } diff --git a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyCypherGraphFactory.java b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyCypherGraphFactory.java index a0642e00d..4fdc0be7b 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyCypherGraphFactory.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyCypherGraphFactory.java @@ -26,7 +26,7 @@ import org.neo4j.graphalgo.api.GraphFactory; import org.neo4j.graphalgo.api.GraphSetup; import org.neo4j.graphalgo.api.WeightMapping; -import org.neo4j.graphalgo.core.IdMap; +import org.neo4j.graphalgo.core.MappingIdMap; import org.neo4j.graphalgo.core.NullWeightMap; import org.neo4j.graphalgo.core.WeightMap; import org.neo4j.graphalgo.core.utils.RawValues; @@ -63,7 +63,7 @@ public HeavyCypherGraphFactory( static class Nodes { private final long offset; private final long rows; - IdMap idMap; + MappingIdMap idMap; WeightMap nodeWeights; WeightMap nodeProps; private final Map nodeProperties; @@ -73,7 +73,7 @@ static class Nodes { Nodes( long offset, long rows, - IdMap idMap, + MappingIdMap idMap, WeightMap nodeWeights, WeightMap nodeProps, Map nodeProperties, @@ -279,7 +279,7 @@ private Nodes batchLoadNodes(int batchSize) { return new Nodes( 0L, total, - new IdMap(graphIds,nodeToGraphIds), + new MappingIdMap(graphIds,nodeToGraphIds), null,null, nodeProperties, setup.nodeDefaultWeight, @@ -305,7 +305,7 @@ private boolean canBatchLoad(int batchSize, String statement) { private Relationships loadRelationships(long offset, int batchSize, Nodes nodes) { - IdMap idMap = nodes.idMap; + MappingIdMap idMap = nodes.idMap; int nodeCount = idMap.size(); int capacity = batchSize == NO_BATCH ? nodeCount : batchSize; @@ -357,7 +357,7 @@ public boolean visit(Result.ResultRow row) throws RuntimeException { private Nodes loadNodes(long offset, int batchSize) { int capacity = batchSize == NO_BATCH ? INITIAL_NODE_COUNT : batchSize; - final IdMap idMap = new IdMap(capacity); + final MappingIdMap idMap = new MappingIdMap(capacity); Map nodeProperties = new HashMap<>(); for (PropertyMapping propertyMapping : setup.nodePropertyMappings) { diff --git a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyGraph.java b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyGraph.java index 73eeb0725..a9b95d480 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyGraph.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyGraph.java @@ -59,17 +59,17 @@ public HeavyGraph( @Override public long nodeCount() { - return nodeIdMap.size(); + return nodeIdMap.nodeCount(); } @Override public void forEachNode(IntPredicate consumer) { - nodeIdMap.forEach(consumer); + nodeIdMap.forEachNode(consumer); } @Override public PrimitiveIntIterator nodeIterator() { - return nodeIdMap.iterator(); + return nodeIdMap.nodeIterator(); } @Override @@ -97,7 +97,7 @@ public void forEachRelationship( @Override public int toMappedNodeId(long originalNodeId) { - return nodeIdMap.get(originalNodeId); + return nodeIdMap.toMappedNodeId(originalNodeId); } @Override diff --git a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyGraphFactory.java b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyGraphFactory.java index 6ce6a7ebb..da88b4449 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyGraphFactory.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/HeavyGraphFactory.java @@ -24,7 +24,7 @@ import org.neo4j.graphalgo.api.GraphSetup; import org.neo4j.graphalgo.api.WeightMapping; import org.neo4j.graphalgo.core.IdMap; -import org.neo4j.graphalgo.core.WeightMap; +import org.neo4j.graphalgo.core.MappingIdMap; import org.neo4j.graphalgo.core.utils.ParallelUtil; import org.neo4j.kernel.internal.GraphDatabaseAPI; @@ -48,7 +48,7 @@ public Graph build() { } private Graph importGraph(final int batchSize) { - final IdMap idMap = loadIdMap(); + final MappingIdMap idMap = loadIdMap(); final Supplier relWeights = () -> newWeightMap( dimensions.relWeightId(), diff --git a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/RelationshipImporter.java b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/RelationshipImporter.java index daea3c581..456be5378 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/RelationshipImporter.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/RelationshipImporter.java @@ -27,6 +27,7 @@ import org.neo4j.graphalgo.api.WeightMapping; import org.neo4j.graphalgo.core.GraphDimensions; import org.neo4j.graphalgo.core.IdMap; +import org.neo4j.graphalgo.core.MappingIdMap; import org.neo4j.graphalgo.core.WeightMap; import org.neo4j.graphalgo.core.utils.ImportProgress; import org.neo4j.graphalgo.core.utils.StatementAction; @@ -37,10 +38,8 @@ import org.neo4j.kernel.internal.GraphDatabaseAPI; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.function.Supplier; -import java.util.stream.Stream; final class RelationshipImporter extends StatementAction { @@ -53,7 +52,7 @@ final class RelationshipImporter extends StatementAction { private final int nodeSize; private final int nodeOffset; - private IdMap idMap; + private MappingIdMap idMap; private AdjacencyMatrix matrix; private WeightMapping relWeights; @@ -66,7 +65,7 @@ final class RelationshipImporter extends StatementAction { ImportProgress progress, int batchSize, int nodeOffset, - IdMap idMap, + MappingIdMap idMap, AdjacencyMatrix matrix, PrimitiveIntIterable nodes, Supplier relWeights, diff --git a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/VisitRelationship.java b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/VisitRelationship.java index cb04f7207..0ffb02066 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/VisitRelationship.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/heavyweight/VisitRelationship.java @@ -18,7 +18,7 @@ */ package org.neo4j.graphalgo.core.heavyweight; -import org.neo4j.graphalgo.core.IdMap; +import org.neo4j.graphalgo.core.MappingIdMap; import org.neo4j.graphalgo.core.WeightMap; import org.neo4j.graphalgo.core.loading.ReadHelper; import org.neo4j.graphalgo.core.utils.RawValues; @@ -33,7 +33,7 @@ abstract class VisitRelationship { - private final IdMap idMap; + private final MappingIdMap idMap; private final boolean shouldSort; private int[] targets; @@ -44,7 +44,7 @@ abstract class VisitRelationship { int prevTarget; int sourceGraphId; - VisitRelationship(final IdMap idMap, final boolean shouldSort) { + VisitRelationship(final MappingIdMap idMap, final boolean shouldSort) { this.idMap = idMap; this.shouldSort = shouldSort; if (!shouldSort) { @@ -173,7 +173,7 @@ private static int distinct(final int[] values, final int start, final int len) final class VisitOutgoingNoWeight extends VisitRelationship { - VisitOutgoingNoWeight(final IdMap idMap, final boolean shouldSort) { + VisitOutgoingNoWeight(final MappingIdMap idMap, final boolean shouldSort) { super(idMap, shouldSort); } @@ -185,7 +185,7 @@ void visit(final RelationshipSelectionCursor cursor) { final class VisitIncomingNoWeight extends VisitRelationship { - VisitIncomingNoWeight(final IdMap idMap, final boolean shouldSort) { + VisitIncomingNoWeight(final MappingIdMap idMap, final boolean shouldSort) { super(idMap, shouldSort); } @@ -204,7 +204,7 @@ final class VisitOutgoingWithWeight extends VisitRelationship { VisitOutgoingWithWeight( final Read readOp, final CursorFactory cursors, - final IdMap idMap, + final MappingIdMap idMap, final boolean shouldSort, final WeightMap weights) { super(idMap, shouldSort); @@ -230,7 +230,7 @@ final class VisitIncomingWithWeight extends VisitRelationship { VisitIncomingWithWeight( final Read readOp, final CursorFactory cursors, - final IdMap idMap, + final MappingIdMap idMap, final boolean shouldSort, final WeightMap weights) { super(idMap, shouldSort); @@ -256,7 +256,7 @@ final class VisitUndirectedOutgoingWithWeight extends VisitRelationship { VisitUndirectedOutgoingWithWeight( final Read readOp, final CursorFactory cursors, - final IdMap idMap, + final MappingIdMap idMap, final boolean shouldSort, final WeightMap weights) { super(idMap, shouldSort); diff --git a/core/src/main/java/org/neo4j/graphalgo/core/utils/ParallelUtil.java b/core/src/main/java/org/neo4j/graphalgo/core/utils/ParallelUtil.java index ee912d18d..14a2f68e3 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/utils/ParallelUtil.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/utils/ParallelUtil.java @@ -22,7 +22,7 @@ import org.neo4j.collection.primitive.PrimitiveLongIterable; import org.neo4j.graphalgo.api.BatchNodeIterable; import org.neo4j.graphalgo.api.HugeBatchNodeIterable; -import org.neo4j.graphalgo.core.IdMap; +import org.neo4j.graphalgo.core.MappingIdMap; import org.neo4j.helpers.Exceptions; import java.util.*; @@ -56,13 +56,13 @@ public static Collection batchIterables(int concurrency, i final int batchSize = nodeCount / concurrency; int numberOfBatches = ParallelUtil.threadSize(batchSize, nodeCount); if (numberOfBatches == 1) { - return Collections.singleton(new IdMap.IdIterable(0, nodeCount)); + return Collections.singleton(new MappingIdMap.IdIterable(0, nodeCount)); } PrimitiveIntIterable[] iterators = new PrimitiveIntIterable[numberOfBatches]; Arrays.setAll(iterators, i -> { int start = i * batchSize; int length = Math.min(batchSize, nodeCount - start); - return new IdMap.IdIterable(start, length); + return new MappingIdMap.IdIterable(start, length); }); return Arrays.asList(iterators); } diff --git a/tests/src/test/java/org/neo4j/graphalgo/core/DirectIdMapTest.java b/tests/src/test/java/org/neo4j/graphalgo/core/DirectIdMapTest.java new file mode 100644 index 000000000..2fc7e7385 --- /dev/null +++ b/tests/src/test/java/org/neo4j/graphalgo/core/DirectIdMapTest.java @@ -0,0 +1,116 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.core; + +import org.junit.Test; +import org.neo4j.collection.primitive.PrimitiveIntIterable; +import org.neo4j.collection.primitive.PrimitiveIntIterator; + +import java.util.Collection; +import java.util.stream.IntStream; + +import static org.junit.Assert.*; + +public final class DirectIdMapTest { + + private int size = 20; + + private IntStream ids() { + return IntStream.range(0, size); + } + + @Test + public void basicTest() { + DirectIdMap idMap = new DirectIdMap(size); + assertEquals(size, idMap.nodeCount()); + assertTrue(ids().allMatch(idMap::contains)); + assertTrue(IntStream.range(-100,0).noneMatch(idMap::contains)); + assertTrue(IntStream.range(size,100).noneMatch(idMap::contains)); + + assertTrue(ids().allMatch(i -> idMap.toMappedNodeId(i) == i)); + assertTrue(ids().allMatch(i -> idMap.toOriginalNodeId(i) == i)); + + PrimitiveIntIterator it = idMap.nodeIterator(); + assertTrue(ids().allMatch(i -> it.next() == i)); + } + + @Test + public void shouldReturnSingleIteratorForLargeBatchSize() { + DirectIdMap idMap = new DirectIdMap(size); + + Collection iterables = idMap.batchIterables(100); + assertEquals(1, iterables.size()); + + assertIterables(idMap, ids().toArray(), iterables); + } + + @Test + public void shouldReturnMultipleIteratorsForSmallBatchSize() { + DirectIdMap idMap = new DirectIdMap(size); + + int expectedBatches = size / 3 + (size % 3 > 0 ? 1 :0); + + Collection iterables = idMap.batchIterables(3); + assertEquals(expectedBatches, iterables.size()); + + assertIterables(idMap, ids().toArray(), iterables); + } + + @Test + public void shouldFailForZeroBatchSize() { + DirectIdMap idMap = new DirectIdMap(0); + + try { + idMap.batchIterables(0); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Invalid batch size: 0", e.getMessage()); + } + } + + @Test + public void shouldFailForNegativeBatchSize() { + DirectIdMap idMap = new DirectIdMap(size); + + int batchSize = -10; + + try { + idMap.batchIterables(batchSize); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Invalid batch size: " + batchSize, e.getMessage()); + } + } + + private void assertIterables( + final IdMap idMap, + final int[] ids, + final Collection iterables) { + int i = 0; + for (PrimitiveIntIterable iterable : iterables) { + PrimitiveIntIterator iterator = iterable.iterator(); + while (iterator.hasNext()) { + int next = iterator.next(); + long id = ids[i]; + assertEquals(i++, next); + assertEquals(id, idMap.toOriginalNodeId(next)); + } + } + } +} diff --git a/tests/src/test/java/org/neo4j/graphalgo/core/IdMapTest.java b/tests/src/test/java/org/neo4j/graphalgo/core/IdMapTest.java index 0f3d7c0e2..ac2ecdbbd 100644 --- a/tests/src/test/java/org/neo4j/graphalgo/core/IdMapTest.java +++ b/tests/src/test/java/org/neo4j/graphalgo/core/IdMapTest.java @@ -23,7 +23,6 @@ import org.junit.Test; import org.neo4j.collection.primitive.PrimitiveIntIterable; import org.neo4j.collection.primitive.PrimitiveIntIterator; -import org.neo4j.graphalgo.core.utils.paged.AllocationTracker; import java.util.Collection; import java.util.function.IntToLongFunction; @@ -34,7 +33,7 @@ public final class IdMapTest extends RandomizedTest { @Test public void shouldReturnSingleIteratorForLargeBatchSize() throws Exception { - IdMap idMap = new IdMap(20); + MappingIdMap idMap = new MappingIdMap(20); long[] ids = addRandomIds(idMap); idMap.buildMappedIds(); @@ -46,7 +45,7 @@ public void shouldReturnSingleIteratorForLargeBatchSize() throws Exception { @Test public void shouldReturnMultipleIteratorsForSmallBatchSize() throws Exception { - IdMap idMap = new IdMap(20); + MappingIdMap idMap = new MappingIdMap(20); long[] ids = addRandomIds(idMap); idMap.buildMappedIds(); @@ -63,7 +62,7 @@ public void shouldReturnMultipleIteratorsForSmallBatchSize() throws Exception { @Test public void shouldFailForZeroBatchSize() throws Exception { - IdMap idMap = new IdMap(20); + MappingIdMap idMap = new MappingIdMap(20); addRandomIds(idMap); idMap.buildMappedIds(); @@ -77,7 +76,7 @@ public void shouldFailForZeroBatchSize() throws Exception { @Test public void shouldFailForNegativeBatchSize() throws Exception { - IdMap idMap = new IdMap(20); + MappingIdMap idMap = new MappingIdMap(20); addRandomIds(idMap); idMap.buildMappedIds(); @@ -107,7 +106,7 @@ private void assertIterables( } } - private long[] addRandomIds(final IdMap idMap) { + private long[] addRandomIds(final MappingIdMap idMap) { LongHashSet seen = new LongHashSet(); return addSomeIds(idMap, i -> { long id; @@ -118,7 +117,7 @@ private long[] addRandomIds(final IdMap idMap) { }); } - private long[] addSomeIds(final IdMap idMap, IntToLongFunction newId) { + private long[] addSomeIds(final MappingIdMap idMap, IntToLongFunction newId) { int iterations = between(10, 20); long[] ids = new long[iterations]; for (int i = 0; i < iterations; i++) {