Skip to content

Commit df3abef

Browse files
committed
Polishing.
Refine MongoVector factory methods for a more natural adoption and terminology when creating vectors. See #4706
1 parent d37fa9e commit df3abef

File tree

3 files changed

+160
-13
lines changed

3 files changed

+160
-13
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
import org.springframework.util.ObjectUtils;
2525

2626
/**
27-
* MongoDB-specific extension to {@link Vector} based on Mongo's {@link BinaryVector}. Note that only float32 and int8
28-
* variants can be represented as floating-point numbers. int1 returns an all-zero array for {@link #toFloatArray()} and
29-
* {@link #toDoubleArray()}.
27+
* MongoDB-specific extension to {@link Vector} based on Mongo's {@link BinaryVector}. Note that only {@code float32}
28+
* and {@code int8} variants can be represented as floating-point numbers. {@code int1} throws
29+
* {@link UnsupportedOperationException} when calling {@link #toFloatArray()} and {@link #toDoubleArray()}.
3030
*
3131
* @author Mark Paluch
3232
* @since 4.5
@@ -40,15 +40,65 @@ public class MongoVector implements Vector {
4040
}
4141

4242
/**
43-
* Creates a new {@link MongoVector} from the given {@link BinaryVector}.
43+
* Creates a new binary {@link MongoVector} using the given {@link BinaryVector}.
4444
*
4545
* @param v binary vector representation.
46-
* @return the {@link MongoVector} for the given vector values.
46+
* @return the {@link MongoVector} wrapping {@link BinaryVector}.
4747
*/
4848
public static MongoVector of(BinaryVector v) {
4949
return new MongoVector(v);
5050
}
5151

52+
/**
53+
* Creates a new binary {@link MongoVector} using the given {@code data}.
54+
* <p>
55+
* A {@link BinaryVector.DataType#INT8} vector is a vector of 8-bit signed integers where each byte in the vector
56+
* represents an element of a vector, with values in the range {@code [-128, 127]}.
57+
* <p>
58+
* NOTE: The byte array is not copied; changes to the provided array will be referenced in the created
59+
* {@code MongoVector} instance.
60+
*
61+
* @param data the byte array representing the {@link BinaryVector.DataType#INT8} vector data.
62+
* @return the {@link MongoVector} containing the given vector values to be represented as binary {@code int8}.
63+
*/
64+
public static MongoVector ofInt8(byte[] data) {
65+
return of(BinaryVector.int8Vector(data));
66+
}
67+
68+
/**
69+
* Creates a new binary {@link MongoVector} using the given {@code data}.
70+
* <p>
71+
* A {@link BinaryVector.DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the
72+
* vector is a {@code float}.
73+
* <p>
74+
* NOTE: The float array is not copied; changes to the provided array will be referenced in the created
75+
* {@code MongoVector} instance.
76+
*
77+
* @param data the float array representing the {@link BinaryVector.DataType#FLOAT32} vector data.
78+
* @return the {@link MongoVector} containing the given vector values to be represented as binary {@code float32}.
79+
*/
80+
public static MongoVector ofFloat(float... data) {
81+
return of(BinaryVector.floatVector(data));
82+
}
83+
84+
/**
85+
* Creates a new binary {@link MongoVector} from the given {@link Vector}.
86+
* <p>
87+
* A {@link BinaryVector.DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the
88+
* vector is a {@code float}. The given {@link Vector} must be able to return a {@link Vector#toFloatArray() float}
89+
* array.
90+
* <p>
91+
* NOTE: The float array is not copied; changes to the provided array will be referenced in the created
92+
* {@code MongoVector} instance.
93+
*
94+
* @param v the
95+
* @return the {@link MongoVector} using vector values from the given {@link Vector} to be represented as binary
96+
* float32.
97+
*/
98+
public static MongoVector fromFloat(Vector v) {
99+
return of(BinaryVector.floatVector(v.toFloatArray()));
100+
}
101+
52102
@Override
53103
public Class<? extends Number> getType() {
54104

@@ -90,6 +140,11 @@ public int size() {
90140
return 0;
91141
}
92142

143+
/**
144+
* {@inheritDoc}
145+
*
146+
* @throws UnsupportedOperationException if the underlying data type is {@code int1} {@link PackedBitBinaryVector}.
147+
*/
93148
@Override
94149
public float[] toFloatArray() {
95150

@@ -102,14 +157,22 @@ public float[] toFloatArray() {
102157

103158
if (v instanceof Int8BinaryVector i) {
104159

105-
float[] result = new float[i.getData().length];
106-
System.arraycopy(i.getData(), 0, result, 0, result.length);
160+
byte[] data = i.getData();
161+
float[] result = new float[data.length];
162+
for (int j = 0; j < data.length; j++) {
163+
result[j] = data[j];
164+
}
107165
return result;
108166
}
109167

110-
return new float[size()];
168+
throw new UnsupportedOperationException("Cannot return float array for " + v.getClass());
111169
}
112170

171+
/**
172+
* {@inheritDoc}
173+
*
174+
* @throws UnsupportedOperationException if the underlying data type is {@code int1} {@link PackedBitBinaryVector}.
175+
*/
113176
@Override
114177
public double[] toDoubleArray() {
115178

@@ -126,12 +189,15 @@ public double[] toDoubleArray() {
126189

127190
if (v instanceof Int8BinaryVector i) {
128191

129-
double[] result = new double[i.getData().length];
130-
System.arraycopy(i.getData(), 0, result, 0, result.length);
192+
byte[] data = i.getData();
193+
double[] result = new double[data.length];
194+
for (int j = 0; j < data.length; j++) {
195+
result[j] = data[j];
196+
}
131197
return result;
132198
}
133199

134-
return new double[size()];
200+
throw new UnsupportedOperationException("Cannot return double array for " + v.getClass());
135201
}
136202

137203
@Override

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,29 +138,31 @@ public void shouldReadAndWriteBinFloat32Vectors() {
138138

139139
WithVectors source = new WithVectors();
140140
source.binVector = BinaryVector.floatVector(new float[] { 1.1f, 2.2f, 3.3f });
141-
source.vector = MongoVector.of(source.binVector);
141+
source.vector = MongoVector.ofFloat(new float[] { 1.1f, 2.2f, 3.3f });
142142

143143
template.save(source);
144144

145145
WithVectors loaded = template.findOne(query(where("id").is(source.id)), WithVectors.class);
146146

147147
assertThat(loaded.vector).isEqualTo(source.vector);
148148
assertThat(loaded.binVector).isEqualTo(source.binVector);
149+
assertThat(loaded.binVector).isEqualTo(source.vector.getSource());
149150
}
150151

151152
@Test // GH-4706
152153
public void shouldReadAndWriteBinInt8Vectors() {
153154

154155
WithVectors source = new WithVectors();
155156
source.binVector = BinaryVector.int8Vector(new byte[] { 1, 2, 3 });
156-
source.vector = MongoVector.of(source.binVector);
157+
source.vector = MongoVector.ofInt8(new byte[] { 1, 2, 3 });
157158

158159
template.save(source);
159160

160161
WithVectors loaded = template.findOne(query(where("id").is(source.id)), WithVectors.class);
161162

162163
assertThat(loaded.vector).isEqualTo(source.vector);
163164
assertThat(loaded.binVector).isEqualTo(source.binVector);
165+
assertThat(loaded.binVector).isEqualTo(source.vector.getSource());
164166
}
165167

166168
@Test // GH-4706
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.mongodb.core.mapping;
17+
18+
import static org.springframework.data.mongodb.test.util.Assertions.*;
19+
20+
import org.bson.BinaryVector;
21+
import org.bson.Float32BinaryVector;
22+
import org.junit.jupiter.api.Test;
23+
24+
import org.springframework.data.domain.Vector;
25+
26+
/**
27+
* Unit tests for {@link MongoVector}.
28+
*
29+
* @author Mark Paluch
30+
*/
31+
class MongoVectorUnitTests {
32+
33+
@Test // GH-4706
34+
void shouldReturnInt8AsFloatingPoints() {
35+
36+
MongoVector vector = MongoVector.ofInt8(new byte[] { 1, 2, 3 });
37+
38+
assertThat(vector.toDoubleArray()).contains(1, 2, 3);
39+
assertThat(vector.toFloatArray()).contains(1, 2, 3);
40+
}
41+
42+
@Test // GH-4706
43+
void shouldReturnFloatAsFloatingPoints() {
44+
45+
MongoVector vector = MongoVector.ofFloat(1f, 2f, 3f);
46+
47+
assertThat(vector.toDoubleArray()).contains(1, 2, 3);
48+
assertThat(vector.toFloatArray()).contains(1, 2, 3);
49+
}
50+
51+
@Test // GH-4706
52+
void ofFloatIsNotEqualToVectorOf() {
53+
54+
MongoVector mv = MongoVector.ofFloat(1f, 2f, 3f);
55+
Vector v = Vector.of(1f, 2f, 3f);
56+
57+
assertThat(v).isNotEqualTo(mv);
58+
}
59+
60+
@Test // GH-4706
61+
void mongoVectorCanAdaptToFloatVector() {
62+
63+
Vector v = Vector.of(1f, 2f, 3f);
64+
MongoVector mv = MongoVector.fromFloat(v);
65+
66+
assertThat(mv.toFloatArray()).isEqualTo(v.toFloatArray());
67+
assertThat(mv.getSource()).isInstanceOf(Float32BinaryVector.class);
68+
}
69+
70+
@Test // GH-4706
71+
void shouldNotReturnFloatsForPackedBit() {
72+
73+
MongoVector vector = MongoVector.of(BinaryVector.packedBitVector(new byte[] { 1, 2, 3 }, (byte) 1));
74+
75+
assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(vector::toFloatArray);
76+
assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(vector::toDoubleArray);
77+
}
78+
79+
}

0 commit comments

Comments
 (0)