diff --git a/firebase-firestore/CHANGELOG.md b/firebase-firestore/CHANGELOG.md index cf3d3772d64..4f458570ec4 100644 --- a/firebase-firestore/CHANGELOG.md +++ b/firebase-firestore/CHANGELOG.md @@ -1,5 +1,5 @@ # Unreleased - +* [feature] Add support for the VectorValue type. [#6154](//github.com/firebase/firebase-android-sdk/pull/6154) # 25.0.0 * [feature] Enable queries with range & inequality filters on multiple fields. [#5729](//github.com/firebase/firebase-android-sdk/pull/5729) diff --git a/firebase-firestore/README.md b/firebase-firestore/README.md index 65bcaa62520..73df2a91094 100644 --- a/firebase-firestore/README.md +++ b/firebase-firestore/README.md @@ -113,7 +113,7 @@ it on behalf. Run below to format Java code: ```bash -./gradlew :firebase-firestore:googleJavaFormat +./gradlew :firebase-firestore:spotlessApply ``` See [here](../README.md#code-formatting) if you want to be able to format code diff --git a/firebase-firestore/api.txt b/firebase-firestore/api.txt index 174d4b07e4d..3c36326eec6 100644 --- a/firebase-firestore/api.txt +++ b/firebase-firestore/api.txt @@ -123,6 +123,7 @@ package com.google.firebase.firestore { method @Nullable public String getString(@NonNull String); method @Nullable public com.google.firebase.Timestamp getTimestamp(@NonNull String); method @Nullable public com.google.firebase.Timestamp getTimestamp(@NonNull String, @NonNull com.google.firebase.firestore.DocumentSnapshot.ServerTimestampBehavior); + method @Nullable public com.google.firebase.firestore.VectorValue getVectorValue(@NonNull String); method @Nullable public T toObject(@NonNull Class); method @Nullable public T toObject(@NonNull Class, @NonNull com.google.firebase.firestore.DocumentSnapshot.ServerTimestampBehavior); } @@ -152,6 +153,7 @@ package com.google.firebase.firestore { method @NonNull public static com.google.firebase.firestore.FieldValue increment(long); method @NonNull public static com.google.firebase.firestore.FieldValue increment(double); method @NonNull public static com.google.firebase.firestore.FieldValue serverTimestamp(); + method @NonNull public static com.google.firebase.firestore.VectorValue vector(@NonNull double[]); } public class Filter { @@ -554,6 +556,10 @@ package com.google.firebase.firestore { method @NonNull public com.google.firebase.firestore.TransactionOptions.Builder setMaxAttempts(int); } + public class VectorValue { + method @NonNull public double[] toArray(); + } + public class WriteBatch { method @NonNull public com.google.android.gms.tasks.Task commit(); method @NonNull public com.google.firebase.firestore.WriteBatch delete(@NonNull com.google.firebase.firestore.DocumentReference); diff --git a/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/NumericTransformsTest.java b/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/NumericTransformsTest.java index 47651f0f29d..af4687eeb12 100644 --- a/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/NumericTransformsTest.java +++ b/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/NumericTransformsTest.java @@ -34,7 +34,7 @@ @RunWith(AndroidJUnit4.class) public class NumericTransformsTest { - private static final double DOUBLE_EPSILON = 0.000001; + public static final double DOUBLE_EPSILON = 0.000001; // A document reference to read and write to. private DocumentReference docRef; diff --git a/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/VectorTest.java b/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/VectorTest.java new file mode 100644 index 00000000000..8f51a3800d8 --- /dev/null +++ b/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/VectorTest.java @@ -0,0 +1,375 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.firebase.firestore; + +import static com.google.firebase.firestore.NumericTransformsTest.DOUBLE_EPSILON; +import static com.google.firebase.firestore.testutil.IntegrationTestUtil.checkOnlineAndOfflineResultsMatch; +import static com.google.firebase.firestore.testutil.IntegrationTestUtil.testCollection; +import static com.google.firebase.firestore.testutil.IntegrationTestUtil.testDocument; +import static com.google.firebase.firestore.testutil.IntegrationTestUtil.waitFor; +import static com.google.firebase.firestore.testutil.TestUtil.map; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.firebase.firestore.testutil.EventAccumulator; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(AndroidJUnit4.class) +public class VectorTest { + + @Test + public void writeAndReadVectorEmbeddings() throws ExecutionException, InterruptedException { + Map expected = new HashMap<>(); + DocumentReference randomDoc = testDocument(); + + waitFor( + randomDoc.set( + map( + "vector0", + FieldValue.vector(new double[] {0.0}), + "vector1", + FieldValue.vector(new double[] {1, 2, 3.99})))); + waitFor( + randomDoc.set( + map( + "vector0", + FieldValue.vector(new double[] {0.0}), + "vector1", + FieldValue.vector(new double[] {1, 2, 3.99}), + "vector2", + FieldValue.vector(new double[] {0, 0, 0})))); + waitFor(randomDoc.update(map("vector3", FieldValue.vector(new double[] {-1, -200, -9999})))); + + expected.put("vector0", FieldValue.vector(new double[] {0.0})); + expected.put("vector1", FieldValue.vector(new double[] {1, 2, 3.99})); + expected.put("vector2", FieldValue.vector(new double[] {0, 0, 0})); + expected.put("vector3", FieldValue.vector(new double[] {-1, -200, -9999})); + + DocumentSnapshot actual = waitFor(randomDoc.get()); + + assertTrue(actual.get("vector0") instanceof VectorValue); + assertTrue(actual.get("vector1") instanceof VectorValue); + assertTrue(actual.get("vector2") instanceof VectorValue); + assertTrue(actual.get("vector3") instanceof VectorValue); + + assertArrayEquals( + expected.get("vector0").toArray(), + actual.get("vector0", VectorValue.class).toArray(), + DOUBLE_EPSILON); + assertArrayEquals( + expected.get("vector1").toArray(), + actual.get("vector1", VectorValue.class).toArray(), + DOUBLE_EPSILON); + assertArrayEquals( + expected.get("vector2").toArray(), + actual.get("vector2", VectorValue.class).toArray(), + DOUBLE_EPSILON); + assertArrayEquals( + expected.get("vector3").toArray(), + actual.get("vector3", VectorValue.class).toArray(), + DOUBLE_EPSILON); + + assertArrayEquals( + expected.get("vector0").toArray(), + actual.getVectorValue("vector0").toArray(), + DOUBLE_EPSILON); + assertArrayEquals( + expected.get("vector1").toArray(), + actual.getVectorValue("vector1").toArray(), + DOUBLE_EPSILON); + assertArrayEquals( + expected.get("vector2").toArray(), + actual.getVectorValue("vector2").toArray(), + DOUBLE_EPSILON); + assertArrayEquals( + expected.get("vector3").toArray(), + actual.getVectorValue("vector3").toArray(), + DOUBLE_EPSILON); + } + + @Test + public void listenToDocumentsWithVectors() throws Throwable { + final Semaphore semaphore = new Semaphore(0); + ListenerRegistration registration = null; + CollectionReference randomColl = testCollection(); + DocumentReference ref = randomColl.document(); + AtomicReference failureMessage = new AtomicReference(null); + int totalPermits = 5; + + try { + registration = + randomColl + .whereEqualTo("purpose", "vector tests") + .addSnapshotListener( + (value, error) -> { + try { + DocumentSnapshot docSnap = + value.isEmpty() ? null : value.getDocuments().get(0); + + switch (semaphore.availablePermits()) { + case 0: + assertNull(docSnap); + ref.set( + map( + "purpose", "vector tests", + "vector0", FieldValue.vector(new double[] {0.0}), + "vector1", FieldValue.vector(new double[] {1, 2, 3.99}))); + break; + case 1: + assertNotNull(docSnap); + + assertEquals( + docSnap.getVectorValue("vector0"), + FieldValue.vector(new double[] {0.0})); + assertEquals( + docSnap.getVectorValue("vector1"), + FieldValue.vector(new double[] {1, 2, 3.99})); + + ref.set( + map( + "purpose", + "vector tests", + "vector0", + FieldValue.vector(new double[] {0.0}), + "vector1", + FieldValue.vector(new double[] {1, 2, 3.99}), + "vector2", + FieldValue.vector(new double[] {0, 0, 0}))); + break; + case 2: + assertNotNull(docSnap); + + assertEquals( + docSnap.getVectorValue("vector0"), + FieldValue.vector(new double[] {0.0})); + assertEquals( + docSnap.getVectorValue("vector1"), + FieldValue.vector(new double[] {1, 2, 3.99})); + assertEquals( + docSnap.getVectorValue("vector2"), + FieldValue.vector(new double[] {0, 0, 0})); + + ref.update( + map("vector3", FieldValue.vector(new double[] {-1, -200, -999}))); + break; + case 3: + assertNotNull(docSnap); + + assertEquals( + docSnap.getVectorValue("vector0"), + FieldValue.vector(new double[] {0.0})); + assertEquals( + docSnap.getVectorValue("vector1"), + FieldValue.vector(new double[] {1, 2, 3.99})); + assertEquals( + docSnap.getVectorValue("vector2"), + FieldValue.vector(new double[] {0, 0, 0})); + assertEquals( + docSnap.getVectorValue("vector3"), + FieldValue.vector(new double[] {-1, -200, -999})); + + ref.delete(); + break; + case 4: + assertNull(docSnap); + break; + } + } catch (Throwable t) { + failureMessage.set(t); + semaphore.release(totalPermits); + } + + semaphore.release(); + }); + + semaphore.acquire(totalPermits); + } finally { + if (registration != null) { + registration.remove(); + } + + if (failureMessage.get() != null) { + throw failureMessage.get(); + } + } + } + + /** Verifies that the SDK orders vector fields the same way as the backend. */ + @Test + public void vectorFieldOrder() throws Exception { + // We validate that the SDK orders the vector field the same way as the backend + // by comparing the sort order of vector fields from a Query.get() and + // Query.addSnapshotListener(). Query.addSnapshotListener() will return sort order + // of the SDK, and Query.get() will return sort order of the backend. + + CollectionReference randomColl = testCollection(); + + // Test data in the order that we expect the backend to sort it. + List> docsInOrder = + Arrays.asList( + map("embedding", Arrays.asList(1, 2, 3, 4, 5, 6)), + map("embedding", Arrays.asList(100)), + map("embedding", FieldValue.vector(new double[] {Double.NEGATIVE_INFINITY})), + map("embedding", FieldValue.vector(new double[] {-100})), + map("embedding", FieldValue.vector(new double[] {100})), + map("embedding", FieldValue.vector(new double[] {Double.POSITIVE_INFINITY})), + map("embedding", FieldValue.vector(new double[] {1, 2})), + map("embedding", FieldValue.vector(new double[] {2, 2})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3, 4})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3, 4, 5})), + map("embedding", FieldValue.vector(new double[] {1, 2, 100, 4, 4})), + map("embedding", FieldValue.vector(new double[] {100, 2, 3, 4, 5})), + map("embedding", map()), + map("embedding", map("HELLO", "WORLD")), + map("embedding", map("hello", "world"))); + + // Add docs and store doc IDs + List docIds = new ArrayList(); + for (int i = 0; i < docsInOrder.size(); i++) { + DocumentReference docRef = waitFor(randomColl.add(docsInOrder.get(i))); + docIds.add(docRef.getId()); + } + + // Test query + Query orderedQuery = randomColl.orderBy("embedding"); + + // Run query with snapshot listener + EventAccumulator eventAccumulator = new EventAccumulator(); + ListenerRegistration registration = + orderedQuery.addSnapshotListener(eventAccumulator.listener()); + + List watchIds = new ArrayList<>(); + try { + // Get doc IDs from snapshot listener + QuerySnapshot querySnapshot = eventAccumulator.await(); + watchIds = + querySnapshot.getDocuments().stream() + .map(documentSnapshot -> documentSnapshot.getId()) + .collect(Collectors.toList()); + } finally { + registration.remove(); + } + + // Run query with get() and get doc IDs + QuerySnapshot querySnapshot = waitFor(orderedQuery.get()); + List getIds = + querySnapshot.getDocuments().stream().map(ds -> ds.getId()).collect(Collectors.toList()); + + // Assert that get and snapshot listener requests sort docs in the same, expected order + assertArrayEquals(docIds.toArray(new String[0]), getIds.toArray(new String[0])); + assertArrayEquals(docIds.toArray(new String[0]), watchIds.toArray(new String[0])); + } + + /** Verifies that the SDK orders vector fields the same way for online and offline queries*/ + @Test + public void vectorFieldOrderOnlineAndOffline() throws Exception { + CollectionReference randomColl = testCollection(); + + // Test data in the order that we expect the backend to sort it. + List> docsInOrder = + Arrays.asList( + map("embedding", Arrays.asList(1, 2, 3, 4, 5, 6)), + map("embedding", Arrays.asList(100)), + map("embedding", FieldValue.vector(new double[] {Double.NEGATIVE_INFINITY})), + map("embedding", FieldValue.vector(new double[] {-100})), + map("embedding", FieldValue.vector(new double[] {100})), + map("embedding", FieldValue.vector(new double[] {Double.POSITIVE_INFINITY})), + map("embedding", FieldValue.vector(new double[] {1, 2})), + map("embedding", FieldValue.vector(new double[] {2, 2})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3, 4})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3, 4, 5})), + map("embedding", FieldValue.vector(new double[] {1, 2, 100, 4, 4})), + map("embedding", FieldValue.vector(new double[] {100, 2, 3, 4, 5})), + map("embedding", map()), + map("embedding", map("HELLO", "WORLD")), + map("embedding", map("hello", "world"))); + + // Add docs and store doc IDs + List docIds = new ArrayList(); + for (int i = 0; i < docsInOrder.size(); i++) { + DocumentReference docRef = waitFor(randomColl.add(docsInOrder.get(i))); + docIds.add(docRef.getId()); + } + + // Test query + Query orderedQuery = randomColl.orderBy("embedding"); + + // Run query with snapshot listener + checkOnlineAndOfflineResultsMatch(orderedQuery, docIds.stream().toArray(String[]::new)); + } + + /** Verifies that the SDK filters vector fields the same way for online and offline queries*/ + @Test + public void vectorFieldFilterOnlineAndOffline() throws Exception { + CollectionReference randomColl = testCollection(); + + // Test data in the order that we expect the backend to sort it. + List> docsInOrder = + Arrays.asList( + map("embedding", Arrays.asList(1, 2, 3, 4, 5, 6)), + map("embedding", Arrays.asList(100)), + map("embedding", FieldValue.vector(new double[] {Double.NEGATIVE_INFINITY})), + map("embedding", FieldValue.vector(new double[] {-100})), + map("embedding", FieldValue.vector(new double[] {100})), + map("embedding", FieldValue.vector(new double[] {Double.POSITIVE_INFINITY})), + map("embedding", FieldValue.vector(new double[] {1, 2})), + map("embedding", FieldValue.vector(new double[] {2, 2})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3, 4})), + map("embedding", FieldValue.vector(new double[] {1, 2, 3, 4, 5})), + map("embedding", FieldValue.vector(new double[] {1, 2, 100, 4, 4})), + map("embedding", FieldValue.vector(new double[] {100, 2, 3, 4, 5})), + map("embedding", map()), + map("embedding", map("HELLO", "WORLD")), + map("embedding", map("hello", "world"))); + + // Add docs and store doc IDs + List docIds = new ArrayList(); + for (int i = 0; i < docsInOrder.size(); i++) { + DocumentReference docRef = waitFor(randomColl.add(docsInOrder.get(i))); + docIds.add(docRef.getId()); + } + + Query orderedQueryLessThan = + randomColl + .orderBy("embedding") + .whereLessThan("embedding", FieldValue.vector(new double[] {1, 2, 100, 4, 4})); + checkOnlineAndOfflineResultsMatch( + orderedQueryLessThan, docIds.subList(2, 11).stream().toArray(String[]::new)); + + Query orderedQueryGreaterThan = + randomColl + .orderBy("embedding") + .whereGreaterThan("embedding", FieldValue.vector(new double[] {1, 2, 100, 4, 4})); + checkOnlineAndOfflineResultsMatch( + orderedQueryGreaterThan, docIds.subList(12, 13).stream().toArray(String[]::new)); + } +} diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/DocumentSnapshot.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/DocumentSnapshot.java index eddc041768b..526d12ec498 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/DocumentSnapshot.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/DocumentSnapshot.java @@ -484,6 +484,19 @@ public DocumentReference getReference() { return new DocumentReference(key, firestore); } + /** + * Returns the value of the field as a {@link VectorValue} or + * `null` if the field does not exist in the document. + * + * @param field The path to the field. + * @throws RuntimeException if the value is not a VectorValue. + * @return The value of the field. + */ + @Nullable + public VectorValue getVectorValue(@NonNull String field) { + return (VectorValue) get(field); + } + @Nullable private T getTypedValue(String field, Class clazz) { checkNotNull(field, "Provided field must not be null."); diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/FieldValue.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/FieldValue.java index 985ef6fd83d..48f67e50e12 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/FieldValue.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/FieldValue.java @@ -182,4 +182,15 @@ public static FieldValue increment(long l) { public static FieldValue increment(double l) { return new NumericIncrementFieldValue(l); } + + /** + * Creates a new {@link VectorValue} constructed with a copy of the given array of doubles. + * + * @param values Array of doubles to be copied to create a {@link VectorValue}. + * @return A new {@link VectorValue} constructed with a copy of the given array of doubles. + */ + @NonNull + public static VectorValue vector(@NonNull double[] values) { + return new VectorValue(values); + } } diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java index c7c399be117..297479d0262 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java @@ -32,6 +32,7 @@ import com.google.firebase.firestore.model.DatabaseId; import com.google.firebase.firestore.model.FieldPath; import com.google.firebase.firestore.model.ObjectValue; +import com.google.firebase.firestore.model.Values; import com.google.firebase.firestore.model.mutation.ArrayTransformOperation; import com.google.firebase.firestore.model.mutation.FieldMask; import com.google.firebase.firestore.model.mutation.NumericIncrementTransformOperation; @@ -440,6 +441,8 @@ private Value parseScalarValue(Object input, ParseContext context) { databaseId.getDatabaseId(), ((DocumentReference) input).getPath())) .build(); + } else if (input instanceof VectorValue) { + return parseVectorValue(((VectorValue) input), context); } else if (input.getClass().isArray()) { throw context.createError("Arrays are not supported; use a List instead"); } else { @@ -447,6 +450,15 @@ private Value parseScalarValue(Object input, ParseContext context) { } } + private Value parseVectorValue(VectorValue vector, ParseContext context) { + MapValue.Builder mapBuilder = MapValue.newBuilder(); + + mapBuilder.putFields(Values.TYPE_KEY, Values.VECTOR_VALUE_TYPE); + mapBuilder.putFields(Values.VECTOR_MAP_VECTORS_KEY, parseData(vector.toList(), context)); + + return Value.newBuilder().setMapValue(mapBuilder).build(); + } + private Value parseTimestamp(Timestamp timestamp) { // Firestore backend truncates precision down to microseconds. To ensure offline mode works // the same with regards to truncation, perform the truncation immediately without waiting for diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataWriter.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataWriter.java index ab831aca914..d6ac7b90bba 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataWriter.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataWriter.java @@ -27,6 +27,7 @@ import static com.google.firebase.firestore.model.Values.TYPE_ORDER_SERVER_TIMESTAMP; import static com.google.firebase.firestore.model.Values.TYPE_ORDER_STRING; import static com.google.firebase.firestore.model.Values.TYPE_ORDER_TIMESTAMP; +import static com.google.firebase.firestore.model.Values.TYPE_ORDER_VECTOR; import static com.google.firebase.firestore.model.Values.typeOrder; import static com.google.firebase.firestore.util.Assert.fail; @@ -34,6 +35,7 @@ import com.google.firebase.Timestamp; import com.google.firebase.firestore.model.DatabaseId; import com.google.firebase.firestore.model.DocumentKey; +import com.google.firebase.firestore.model.Values; import com.google.firebase.firestore.util.Logger; import com.google.firestore.v1.ArrayValue; import com.google.firestore.v1.Value; @@ -86,6 +88,8 @@ public Object convertValue(Value value) { case TYPE_ORDER_GEOPOINT: return new GeoPoint( value.getGeoPointValue().getLatitude(), value.getGeoPointValue().getLongitude()); + case TYPE_ORDER_VECTOR: + return convertVectorValue(value.getMapValue().getFieldsMap()); default: throw fail("Unknown value type: " + value.getValueTypeCase()); } @@ -99,6 +103,18 @@ Map convertObject(Map mapValue) { return result; } + VectorValue convertVectorValue(Map mapValue) { + List values = + mapValue.get(Values.VECTOR_MAP_VECTORS_KEY).getArrayValue().getValuesList(); + + double[] doubles = new double[values.size()]; + for (int i = 0; i < values.size(); i++) { + doubles[i] = values.get(i).getDoubleValue(); + } + + return new VectorValue(doubles); + } + private Object convertServerTimestamp(Value serverTimestampValue) { switch (serverTimestampBehavior) { case PREVIOUS: diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/VectorValue.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/VectorValue.java new file mode 100644 index 00000000000..2f355648376 --- /dev/null +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/VectorValue.java @@ -0,0 +1,81 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.firebase.firestore; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Represent a vector type in Firestore documents. + * Create an instance with {@link FieldValue#vector(double[])}. + */ +public class VectorValue { + private final double[] values; + + VectorValue(@Nullable double[] values) { + this.values = (values == null) ? new double[] {} : values.clone(); + } + + /** + * Returns a representation of the vector as an array of doubles. + * + * @return A representation of the vector as an array of doubles + */ + @NonNull + public double[] toArray() { + return this.values.clone(); + } + + /** + * Package private. + * Returns a representation of the vector as a List. + * + * @return A representation of the vector as an List + */ + @NonNull + List toList() { + ArrayList result = new ArrayList(this.values.length); + for (int i = 0; i < this.values.length; i++) { + result.add(i, this.values[i]); + } + return result; + } + + /** + * Returns true if this VectorValue is equal to the provided object. + * + * @param obj The object to compare against. + * @return Whether this VectorValue is equal to the provided object. + */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + VectorValue otherArray = (VectorValue) obj; + return Arrays.equals(this.values, otherArray.values); + } + + @Override + public int hashCode() { + return Arrays.hashCode(values); + } +} diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/core/Target.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/core/Target.java index 6554380f659..d058e15659e 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/core/Target.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/core/Target.java @@ -246,7 +246,7 @@ private Pair getAscendingBound( switch (fieldFilter.getOperator()) { case LESS_THAN: case LESS_THAN_OR_EQUAL: - filterValue = Values.getLowerBound(fieldFilter.getValue().getValueTypeCase()); + filterValue = Values.getLowerBound(fieldFilter.getValue()); break; case EQUAL: case IN: @@ -311,7 +311,7 @@ private Pair getDescendingBound( switch (fieldFilter.getOperator()) { case GREATER_THAN_OR_EQUAL: case GREATER_THAN: - filterValue = Values.getUpperBound(fieldFilter.getValue().getValueTypeCase()); + filterValue = Values.getUpperBound(fieldFilter.getValue()); filterInclusive = false; break; case EQUAL: diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/index/FirestoreIndexValueWriter.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/index/FirestoreIndexValueWriter.java index c197c8d02f7..f275634957a 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/index/FirestoreIndexValueWriter.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/index/FirestoreIndexValueWriter.java @@ -41,6 +41,7 @@ public class FirestoreIndexValueWriter { public static final int INDEX_TYPE_REFERENCE = 37; public static final int INDEX_TYPE_GEOPOINT = 45; public static final int INDEX_TYPE_ARRAY = 50; + public static final int INDEX_TYPE_VECTOR = 53; public static final int INDEX_TYPE_MAP = 55; public static final int INDEX_TYPE_REFERENCE_SEGMENT = 60; @@ -114,6 +115,9 @@ private void writeIndexValueAux(Value indexValue, DirectionalIndexByteEncoder en if (Values.isMaxValue(indexValue)) { writeValueTypeLabel(encoder, Integer.MAX_VALUE); break; + } else if (Values.isVectorValue(indexValue)) { + writeIndexVector(indexValue.getMapValue(), encoder); + break; } writeIndexMap(indexValue.getMapValue(), encoder); writeTruncationMarker(encoder); @@ -138,6 +142,21 @@ private void writeUnlabeledIndexString( encoder.writeString(stringIndexValue); } + private void writeIndexVector(MapValue mapIndexValue, DirectionalIndexByteEncoder encoder) { + Map map = mapIndexValue.getFieldsMap(); + String key = Values.VECTOR_MAP_VECTORS_KEY; + writeValueTypeLabel(encoder, INDEX_TYPE_VECTOR); + + // Vectors sort first by length + int length = map.get(key).getArrayValue().getValuesCount(); + writeValueTypeLabel(encoder, INDEX_TYPE_NUMBER); + encoder.writeLong(length); + + // Vectors then sort by position value + this.writeIndexString(key, encoder); + this.writeIndexValueAux(map.get(key), encoder); + } + private void writeIndexMap(MapValue mapIndexValue, DirectionalIndexByteEncoder encoder) { writeValueTypeLabel(encoder, INDEX_TYPE_MAP); for (Map.Entry entry : mapIndexValue.getFieldsMap().entrySet()) { diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.java index 2830f810e42..26b06a5d2cb 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.java @@ -37,14 +37,28 @@ import java.util.TreeMap; public class Values { + public static final String TYPE_KEY = "__type__"; public static final Value NAN_VALUE = Value.newBuilder().setDoubleValue(Double.NaN).build(); public static final Value NULL_VALUE = Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(); public static final Value MIN_VALUE = NULL_VALUE; - private static final Value MAX_VALUE_TYPE = Value.newBuilder().setStringValue("__max__").build(); + public static final Value MAX_VALUE_TYPE = Value.newBuilder().setStringValue("__max__").build(); public static final Value MAX_VALUE = Value.newBuilder() - .setMapValue(MapValue.newBuilder().putFields("__type__", MAX_VALUE_TYPE)) + .setMapValue(MapValue.newBuilder().putFields(TYPE_KEY, MAX_VALUE_TYPE)) + .build(); + + public static final Value VECTOR_VALUE_TYPE = + Value.newBuilder().setStringValue("__vector__").build(); + public static final String VECTOR_MAP_VECTORS_KEY = "value"; + private static final Value MIN_VECTOR_VALUE = + Value.newBuilder() + .setMapValue( + MapValue.newBuilder() + .putFields(TYPE_KEY, VECTOR_VALUE_TYPE) + .putFields( + VECTOR_MAP_VECTORS_KEY, + Value.newBuilder().setArrayValue(ArrayValue.newBuilder()).build())) .build(); /** @@ -62,7 +76,8 @@ public class Values { public static final int TYPE_ORDER_REFERENCE = 7; public static final int TYPE_ORDER_GEOPOINT = 8; public static final int TYPE_ORDER_ARRAY = 9; - public static final int TYPE_ORDER_MAP = 10; + public static final int TYPE_ORDER_VECTOR = 10; + public static final int TYPE_ORDER_MAP = 11; public static final int TYPE_ORDER_MAX_VALUE = Integer.MAX_VALUE; @@ -94,6 +109,8 @@ public static int typeOrder(Value value) { return TYPE_ORDER_SERVER_TIMESTAMP; } else if (isMaxValue(value)) { return TYPE_ORDER_MAX_VALUE; + } else if (isVectorValue(value)) { + return TYPE_ORDER_VECTOR; } else { return TYPE_ORDER_MAP; } @@ -122,6 +139,7 @@ public static boolean equals(Value left, Value right) { return numberEquals(left, right); case TYPE_ORDER_ARRAY: return arrayEquals(left, right); + case TYPE_ORDER_VECTOR: case TYPE_ORDER_MAP: return objectEquals(left, right); case TYPE_ORDER_SERVER_TIMESTAMP: @@ -223,6 +241,8 @@ public static int compare(Value left, Value right) { return compareArrays(left.getArrayValue(), right.getArrayValue()); case TYPE_ORDER_MAP: return compareMaps(left.getMapValue(), right.getMapValue()); + case TYPE_ORDER_VECTOR: + return compareVectors(left.getMapValue(), right.getMapValue()); default: throw fail("Invalid value type: " + leftType); } @@ -343,6 +363,23 @@ private static int compareMaps(MapValue left, MapValue right) { return Util.compareBooleans(iterator1.hasNext(), iterator2.hasNext()); } + private static int compareVectors(MapValue left, MapValue right) { + Map leftMap = left.getFieldsMap(); + Map rightMap = right.getFieldsMap(); + + // The vector is a map, but only vector value is compared. + ArrayValue leftArrayValue = leftMap.get(Values.VECTOR_MAP_VECTORS_KEY).getArrayValue(); + ArrayValue rightArrayValue = rightMap.get(Values.VECTOR_MAP_VECTORS_KEY).getArrayValue(); + + int lengthCompare = + Util.compareIntegers(leftArrayValue.getValuesCount(), rightArrayValue.getValuesCount()); + if (lengthCompare != 0) { + return lengthCompare; + } + + return compareArrays(leftArrayValue, rightArrayValue); + } + /** Generate the canonical ID for the provided field value (as used in Target serialization). */ public static String canonicalId(Value value) { StringBuilder builder = new StringBuilder(); @@ -482,70 +519,97 @@ public static Value refValue(DatabaseId databaseId, DocumentKey key) { return value; } + public static Value MIN_BOOLEAN = Value.newBuilder().setBooleanValue(false).build(); + public static Value MIN_NUMBER = Value.newBuilder().setDoubleValue(Double.NaN).build(); + public static Value MIN_TIMESTAMP = + Value.newBuilder() + .setTimestampValue(Timestamp.newBuilder().setSeconds(Long.MIN_VALUE)) + .build(); + public static Value MIN_STRING = Value.newBuilder().setStringValue("").build(); + public static Value MIN_BYTES = Value.newBuilder().setBytesValue(ByteString.EMPTY).build(); + public static Value MIN_REFERENCE = refValue(DatabaseId.EMPTY, DocumentKey.empty()); + public static Value MIN_GEO_POINT = + Value.newBuilder() + .setGeoPointValue(LatLng.newBuilder().setLatitude(-90.0).setLongitude(-180.0)) + .build(); + public static Value MIN_ARRAY = + Value.newBuilder().setArrayValue(ArrayValue.getDefaultInstance()).build(); + public static Value MIN_MAP = + Value.newBuilder().setMapValue(MapValue.getDefaultInstance()).build(); + /** Returns the lowest value for the given value type (inclusive). */ - public static Value getLowerBound(Value.ValueTypeCase valueTypeCase) { - switch (valueTypeCase) { + public static Value getLowerBound(Value value) { + switch (value.getValueTypeCase()) { case NULL_VALUE: return Values.NULL_VALUE; case BOOLEAN_VALUE: - return Value.newBuilder().setBooleanValue(false).build(); + return MIN_BOOLEAN; case INTEGER_VALUE: case DOUBLE_VALUE: - return Value.newBuilder().setDoubleValue(Double.NaN).build(); + return MIN_NUMBER; case TIMESTAMP_VALUE: - return Value.newBuilder() - .setTimestampValue(Timestamp.newBuilder().setSeconds(Long.MIN_VALUE)) - .build(); + return MIN_TIMESTAMP; case STRING_VALUE: - return Value.newBuilder().setStringValue("").build(); + return MIN_STRING; case BYTES_VALUE: - return Value.newBuilder().setBytesValue(ByteString.EMPTY).build(); + return MIN_BYTES; case REFERENCE_VALUE: - return refValue(DatabaseId.EMPTY, DocumentKey.empty()); + return MIN_REFERENCE; case GEO_POINT_VALUE: - return Value.newBuilder() - .setGeoPointValue(LatLng.newBuilder().setLatitude(-90.0).setLongitude(-180.0)) - .build(); + return MIN_GEO_POINT; case ARRAY_VALUE: - return Value.newBuilder().setArrayValue(ArrayValue.getDefaultInstance()).build(); + return MIN_ARRAY; case MAP_VALUE: - return Value.newBuilder().setMapValue(MapValue.getDefaultInstance()).build(); + // VectorValue sorts after ArrayValue and before an empty MapValue + if (isVectorValue(value)) { + return MIN_VECTOR_VALUE; + } + return MIN_MAP; default: - throw new IllegalArgumentException("Unknown value type: " + valueTypeCase); + throw new IllegalArgumentException("Unknown value type: " + value.getValueTypeCase()); } } /** Returns the largest value for the given value type (exclusive). */ - public static Value getUpperBound(Value.ValueTypeCase valueTypeCase) { - switch (valueTypeCase) { + public static Value getUpperBound(Value value) { + switch (value.getValueTypeCase()) { case NULL_VALUE: - return getLowerBound(Value.ValueTypeCase.BOOLEAN_VALUE); + return MIN_BOOLEAN; case BOOLEAN_VALUE: - return getLowerBound(Value.ValueTypeCase.INTEGER_VALUE); + return MIN_NUMBER; case INTEGER_VALUE: case DOUBLE_VALUE: - return getLowerBound(Value.ValueTypeCase.TIMESTAMP_VALUE); + return MIN_TIMESTAMP; case TIMESTAMP_VALUE: - return getLowerBound(Value.ValueTypeCase.STRING_VALUE); + return MIN_STRING; case STRING_VALUE: - return getLowerBound(Value.ValueTypeCase.BYTES_VALUE); + return MIN_BYTES; case BYTES_VALUE: - return getLowerBound(Value.ValueTypeCase.REFERENCE_VALUE); + return MIN_REFERENCE; case REFERENCE_VALUE: - return getLowerBound(Value.ValueTypeCase.GEO_POINT_VALUE); + return MIN_GEO_POINT; case GEO_POINT_VALUE: - return getLowerBound(Value.ValueTypeCase.ARRAY_VALUE); + return MIN_ARRAY; case ARRAY_VALUE: - return getLowerBound(Value.ValueTypeCase.MAP_VALUE); + return MIN_VECTOR_VALUE; case MAP_VALUE: + // VectorValue sorts after ArrayValue and before an empty MapValue + if (isVectorValue(value)) { + return MIN_MAP; + } return MAX_VALUE; default: - throw new IllegalArgumentException("Unknown value type: " + valueTypeCase); + throw new IllegalArgumentException("Unknown value type: " + value.getValueTypeCase()); } } /** Returns true if the Value represents the canonical {@link #MAX_VALUE} . */ public static boolean isMaxValue(Value value) { - return MAX_VALUE_TYPE.equals(value.getMapValue().getFieldsMap().get("__type__")); + return MAX_VALUE_TYPE.equals(value.getMapValue().getFieldsMap().get(TYPE_KEY)); + } + + /** Returns true if the Value represents a VectorValue . */ + public static boolean isVectorValue(Value value) { + return VECTOR_VALUE_TYPE.equals(value.getMapValue().getFieldsMap().get(TYPE_KEY)); } } diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/util/CustomClassMapper.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/util/CustomClassMapper.java index 9852524f51b..6e0df1e6d4a 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/util/CustomClassMapper.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/util/CustomClassMapper.java @@ -29,6 +29,7 @@ import com.google.firebase.firestore.PropertyName; import com.google.firebase.firestore.ServerTimestamp; import com.google.firebase.firestore.ThrowOnExtraProperties; +import com.google.firebase.firestore.VectorValue; import java.lang.reflect.AccessibleObject; import java.lang.reflect.Constructor; import java.lang.reflect.Field; @@ -173,7 +174,8 @@ private static Object serialize(T o, ErrorPath path) { || o instanceof GeoPoint || o instanceof Blob || o instanceof DocumentReference - || o instanceof FieldValue) { + || o instanceof FieldValue + || o instanceof VectorValue) { return o; } else if (o instanceof Uri || o instanceof URI || o instanceof URL) { return o.toString(); @@ -241,6 +243,8 @@ private static T deserializeToClass(Object o, Class clazz, DeserializeCon return (T) convertGeoPoint(o, context); } else if (DocumentReference.class.isAssignableFrom(clazz)) { return (T) convertDocumentReference(o, context); + } else if (VectorValue.class.isAssignableFrom(clazz)) { + return (T) convertVectorValue(o, context); } else if (clazz.isArray()) { throw deserializeError( context.errorPath, "Converting to Arrays is not supported, please use Lists instead"); @@ -528,6 +532,16 @@ private static GeoPoint convertGeoPoint(Object o, DeserializeContext context) { } } + private static VectorValue convertVectorValue(Object o, DeserializeContext context) { + if (o instanceof VectorValue) { + return (VectorValue) o; + } else { + throw deserializeError( + context.errorPath, + "Failed to convert value of type " + o.getClass().getName() + " to VectorValue"); + } + } + private static DocumentReference convertDocumentReference(Object o, DeserializeContext context) { if (o instanceof DocumentReference) { return (DocumentReference) o; diff --git a/firebase-firestore/src/test/java/com/google/firebase/firestore/index/FirestoreIndexValueWriterTest.java b/firebase-firestore/src/test/java/com/google/firebase/firestore/index/FirestoreIndexValueWriterTest.java new file mode 100644 index 00000000000..6acb576666a --- /dev/null +++ b/firebase-firestore/src/test/java/com/google/firebase/firestore/index/FirestoreIndexValueWriterTest.java @@ -0,0 +1,103 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.firebase.firestore; + +import com.google.firebase.firestore.index.DirectionalIndexByteEncoder; +import com.google.firebase.firestore.index.FirestoreIndexValueWriter; +import com.google.firebase.firestore.index.IndexByteEncoder; +import com.google.firebase.firestore.model.DatabaseId; +import com.google.firebase.firestore.model.FieldIndex; +import com.google.firestore.v1.Value; +import java.util.concurrent.ExecutionException; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public class FirestoreIndexValueWriterTest { + + @Test + public void writeIndexValueSupportsVector() throws ExecutionException, InterruptedException { + UserDataReader dataReader = new UserDataReader(DatabaseId.EMPTY); + Value value = dataReader.parseQueryValue(FieldValue.vector(new double[] {1, 2, 3})); + + IndexByteEncoder encoder = new IndexByteEncoder(); + FirestoreIndexValueWriter.INSTANCE.writeIndexValue( + value, encoder.forKind(FieldIndex.Segment.Kind.ASCENDING)); + byte[] actualBytes = encoder.getEncodedBytes(); + + IndexByteEncoder expectedEncoder = new IndexByteEncoder(); + DirectionalIndexByteEncoder expectedDirectionalEncoder = + expectedEncoder.forKind(FieldIndex.Segment.Kind.ASCENDING); + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_VECTOR); // Vector type + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_NUMBER); // Number type + expectedDirectionalEncoder.writeLong(3); // vector length + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_STRING); // String type + expectedDirectionalEncoder.writeString("value"); // Vector value header + expectedDirectionalEncoder.writeLong(FirestoreIndexValueWriter.INDEX_TYPE_ARRAY); // Array type + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_NUMBER); // Number type + expectedDirectionalEncoder.writeDouble(1); // position 0 + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_NUMBER); // Number type + expectedDirectionalEncoder.writeDouble(2); // position 1 + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_NUMBER); // Number type + expectedDirectionalEncoder.writeDouble(3); // position 2 + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.NOT_TRUNCATED); // Array not truncated + expectedDirectionalEncoder.writeInfinity(); + byte[] expectedBytes = expectedEncoder.getEncodedBytes(); + + Assert.assertArrayEquals(actualBytes, expectedBytes); + } + + @Test + public void writeIndexValueSupportsEmptyVector() { + UserDataReader dataReader = new UserDataReader(DatabaseId.EMPTY); + Value value = dataReader.parseQueryValue(FieldValue.vector(new double[] {})); + + // Encode an actual VectorValue + IndexByteEncoder encoder = new IndexByteEncoder(); + FirestoreIndexValueWriter.INSTANCE.writeIndexValue( + value, encoder.forKind(FieldIndex.Segment.Kind.ASCENDING)); + byte[] actualBytes = encoder.getEncodedBytes(); + + // Create the expected representation of the encoded vector + IndexByteEncoder expectedEncoder = new IndexByteEncoder(); + DirectionalIndexByteEncoder expectedDirectionalEncoder = + expectedEncoder.forKind(FieldIndex.Segment.Kind.ASCENDING); + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_VECTOR); // Vector type + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_NUMBER); // Number type + expectedDirectionalEncoder.writeLong(0); // vector length + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.INDEX_TYPE_STRING); // String type + expectedDirectionalEncoder.writeString("value"); // Vector value header + expectedDirectionalEncoder.writeLong(FirestoreIndexValueWriter.INDEX_TYPE_ARRAY); // Array type + expectedDirectionalEncoder.writeLong( + FirestoreIndexValueWriter.NOT_TRUNCATED); // Array not truncated + expectedDirectionalEncoder.writeInfinity(); + byte[] expectedBytes = expectedEncoder.getEncodedBytes(); + + // Assert actual and expected encodings are equal + Assert.assertArrayEquals(actualBytes, expectedBytes); + } +} diff --git a/firebase-firestore/src/test/java/com/google/firebase/firestore/local/SQLiteLocalStoreTest.java b/firebase-firestore/src/test/java/com/google/firebase/firestore/local/SQLiteLocalStoreTest.java index fa44664ad77..63569e6dc85 100644 --- a/firebase-firestore/src/test/java/com/google/firebase/firestore/local/SQLiteLocalStoreTest.java +++ b/firebase-firestore/src/test/java/com/google/firebase/firestore/local/SQLiteLocalStoreTest.java @@ -293,6 +293,80 @@ public void testUsesIndexForLimitQueryWhenIndexIsUpdated() { assertQueryReturned("coll/a", "coll/c"); } + @Test + public void testIndexesVectorValues() { + FieldIndex index = + fieldIndex( + "coll", 0, FieldIndex.INITIAL_STATE, "embedding", FieldIndex.Segment.Kind.ASCENDING); + configureFieldIndexes(singletonList(index)); + + writeMutation(setMutation("coll/arr1", map("embedding", Arrays.asList(0.1, 0.2, 0.3)))); + writeMutation(setMutation("coll/map2", map("embedding", map()))); + writeMutation( + setMutation("coll/doc3", map("embedding", FieldValue.vector(new double[] {4, 5, 6})))); + writeMutation(setMutation("coll/doc4", map("embedding", FieldValue.vector(new double[] {5})))); + + Query query = query("coll").orderBy(orderBy("embedding", "asc")); + executeQuery(query); + assertQueryReturned("coll/arr1", "coll/doc4", "coll/doc3", "coll/map2"); + + query = + query("coll").filter(filter("embedding", "==", FieldValue.vector(new double[] {4, 5, 6}))); + executeQuery(query); + assertQueryReturned("coll/doc3"); + + query = + query("coll").filter(filter("embedding", ">", FieldValue.vector(new double[] {4, 5, 6}))); + executeQuery(query); + assertQueryReturned(); + + query = query("coll").filter(filter("embedding", ">=", FieldValue.vector(new double[] {4}))); + executeQuery(query); + assertQueryReturned("coll/doc4", "coll/doc3"); + + backfillIndexes(); + + query = query("coll").orderBy(orderBy("embedding", "asc")); + executeQuery(query); + assertOverlaysRead(/* byKey= */ 4, /* byCollection= */ 0); + assertOverlayTypes( + keyMap( + "coll/arr1", + CountingQueryEngine.OverlayType.Set, + "coll/map2", + CountingQueryEngine.OverlayType.Set, + "coll/doc3", + CountingQueryEngine.OverlayType.Set, + "coll/doc4", + CountingQueryEngine.OverlayType.Set)); + assertQueryReturned("coll/arr1", "coll/doc4", "coll/doc3", "coll/map2"); + + query = + query("coll").filter(filter("embedding", "==", FieldValue.vector(new double[] {4, 5, 6}))); + executeQuery(query); + assertOverlaysRead(/* byKey= */ 1, /* byCollection= */ 0); + assertOverlayTypes(keyMap("coll/doc3", CountingQueryEngine.OverlayType.Set)); + assertQueryReturned("coll/doc3"); + + query = + query("coll").filter(filter("embedding", ">", FieldValue.vector(new double[] {4, 5, 6}))); + executeQuery(query); + assertOverlaysRead(/* byKey= */ 0, /* byCollection= */ 0); + assertOverlayTypes(keyMap()); + assertQueryReturned(); + + query = query("coll").filter(filter("embedding", ">=", FieldValue.vector(new double[] {4}))); + executeQuery(query); + assertOverlaysRead(/* byKey= */ 2, /* byCollection= */ 0); + assertOverlayTypes( + keyMap( + "coll/doc4", + CountingQueryEngine.OverlayType.Set, + "coll/doc3", + CountingQueryEngine.OverlayType.Set)); + assertQueryReturned("coll/doc4", "coll/doc3"); + } + @Test public void testIndexesServerTimestamps() { FieldIndex index = diff --git a/firebase-firestore/src/test/java/com/google/firebase/firestore/model/ValuesTest.java b/firebase-firestore/src/test/java/com/google/firebase/firestore/model/ValuesTest.java index 267197bc096..6a7dbe9c259 100644 --- a/firebase-firestore/src/test/java/com/google/firebase/firestore/model/ValuesTest.java +++ b/firebase-firestore/src/test/java/com/google/firebase/firestore/model/ValuesTest.java @@ -26,6 +26,7 @@ import com.google.common.testing.EqualsTester; import com.google.firebase.Timestamp; +import com.google.firebase.firestore.FieldValue; import com.google.firebase.firestore.GeoPoint; import com.google.firebase.firestore.testutil.ComparatorTester; import com.google.firebase.firestore.testutil.TestUtil; @@ -34,6 +35,7 @@ import java.util.Calendar; import java.util.Collections; import java.util.Date; +import java.util.LinkedList; import java.util.TimeZone; import org.junit.Test; import org.junit.runner.RunWith; @@ -100,6 +102,8 @@ public void testValueEquality() { .addEqualityGroup(wrap(Arrays.asList("foo", "bar")), wrap(Arrays.asList("foo", "bar"))) .addEqualityGroup(wrap(Arrays.asList("foo", "bar", "baz"))) .addEqualityGroup(wrap(Arrays.asList("foo"))) + .addEqualityGroup(wrap(FieldValue.vector(new double[] {}))) + .addEqualityGroup(wrap(FieldValue.vector(new double[] {1, 2.3, -4}))) .addEqualityGroup(wrap(map("bar", 1, "foo", 2)), wrap(map("foo", 2, "bar", 1))) .addEqualityGroup(wrap(map("bar", 2, "foo", 1))) .addEqualityGroup(wrap(map("bar", 1))) @@ -196,6 +200,12 @@ public void testValueOrdering() { .addEqualityGroup(wrap(Arrays.asList("foo", 2))) .addEqualityGroup(wrap(Arrays.asList("foo", "0"))) + // vector + .addEqualityGroup(wrap(FieldValue.vector(new double[] {}))) + .addEqualityGroup(wrap(FieldValue.vector(new double[] {100}))) + .addEqualityGroup(wrap(FieldValue.vector(new double[] {1, 2, 3}))) + .addEqualityGroup(wrap(FieldValue.vector(new double[] {1, 3, 2}))) + // objects .addEqualityGroup(wrap(map("bar", 0))) .addEqualityGroup(wrap(map("bar", 0, "foo", 1))) @@ -209,47 +219,58 @@ public void testValueOrdering() { public void testLowerBound() { new ComparatorTester() // null first - .addEqualityGroup(wrap(getLowerBound(Value.ValueTypeCase.NULL_VALUE)), wrap((Object) null)) + .addEqualityGroup(wrap(getLowerBound(TestUtil.wrap((Object) null))), wrap((Object) null)) // booleans - .addEqualityGroup(wrap(false), wrap(getLowerBound(Value.ValueTypeCase.BOOLEAN_VALUE))) + .addEqualityGroup(wrap(false), wrap(getLowerBound(TestUtil.wrap(true)))) .addEqualityGroup(wrap(true)) // numbers - .addEqualityGroup(wrap(getLowerBound(Value.ValueTypeCase.DOUBLE_VALUE)), wrap(Double.NaN)) + .addEqualityGroup(wrap(getLowerBound(TestUtil.wrap(1.0))), wrap(Double.NaN)) .addEqualityGroup(wrap(Double.NEGATIVE_INFINITY)) .addEqualityGroup(wrap(Long.MIN_VALUE)) // dates - .addEqualityGroup(wrap(getLowerBound(Value.ValueTypeCase.TIMESTAMP_VALUE))) + .addEqualityGroup(wrap(getLowerBound(TestUtil.wrap(date1)))) .addEqualityGroup(wrap(date1)) // strings - .addEqualityGroup(wrap(getLowerBound(Value.ValueTypeCase.STRING_VALUE)), wrap("")) + .addEqualityGroup(wrap(getLowerBound(TestUtil.wrap("foo"))), wrap("")) .addEqualityGroup(wrap("\000")) // blobs - .addEqualityGroup(wrap(getLowerBound(Value.ValueTypeCase.BYTES_VALUE)), wrap(blob())) + .addEqualityGroup(wrap(getLowerBound(TestUtil.wrap(blob(1, 2, 3)))), wrap(blob())) .addEqualityGroup(wrap(blob(0))) // resource names .addEqualityGroup( - wrap(getLowerBound(Value.ValueTypeCase.REFERENCE_VALUE)), + wrap(getLowerBound(wrapRef(dbId("foo", "bar"), key("x/y")))), wrap(wrapRef(dbId("", ""), key("")))) .addEqualityGroup(wrap(wrapRef(dbId("", ""), key("a/a")))) // geo points .addEqualityGroup( - wrap(getLowerBound(Value.ValueTypeCase.GEO_POINT_VALUE)), wrap(new GeoPoint(-90, -180))) + wrap(getLowerBound(TestUtil.wrap(new GeoPoint(-90, 0)))), wrap(new GeoPoint(-90, -180))) .addEqualityGroup(wrap(new GeoPoint(-90, 0))) // arrays .addEqualityGroup( - wrap(getLowerBound(Value.ValueTypeCase.ARRAY_VALUE)), wrap(Collections.emptyList())) + wrap(getLowerBound(TestUtil.wrap(Collections.singletonList(false)))), + wrap(Collections.emptyList())) .addEqualityGroup(wrap(Collections.singletonList(false))) + // vectors + .addEqualityGroup( + wrap( + getLowerBound( + TestUtil.wrap( + map("__type__", "__vector__", "value", Collections.singletonList(1.0))))), + wrap(map("__type__", "__vector__", "value", new LinkedList()))) + .addEqualityGroup( + wrap(map("__type__", "__vector__", "value", Collections.singletonList(1.0)))) + // objects - .addEqualityGroup(wrap(getLowerBound(Value.ValueTypeCase.MAP_VALUE)), wrap(map())) + .addEqualityGroup(wrap(getLowerBound(TestUtil.wrap(map("foo", "bar")))), wrap(map())) .testCompare(); } @@ -258,43 +279,53 @@ public void testUpperBound() { new ComparatorTester() // null first .addEqualityGroup(wrap((Object) null)) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.NULL_VALUE))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap((Object) null)))) // booleans .addEqualityGroup(wrap(true)) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.BOOLEAN_VALUE))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap(false)))) // numbers .addEqualityGroup(wrap(Long.MAX_VALUE)) .addEqualityGroup(wrap(Double.POSITIVE_INFINITY)) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.DOUBLE_VALUE))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap(1.0)))) // dates .addEqualityGroup(wrap(date1)) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.TIMESTAMP_VALUE))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap(date1)))) // strings .addEqualityGroup(wrap("\000")) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.STRING_VALUE))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap("\000")))) // blobs .addEqualityGroup(wrap(blob(255))) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.BYTES_VALUE))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap(blob(255))))) // resource names .addEqualityGroup(wrap(wrapRef(dbId("", ""), key("a/a")))) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.REFERENCE_VALUE))) + .addEqualityGroup(wrap(getUpperBound(wrapRef(dbId("", ""), key("a/a"))))) // geo points .addEqualityGroup(wrap(new GeoPoint(90, 180))) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.GEO_POINT_VALUE))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap(new GeoPoint(90, 180))))) // arrays .addEqualityGroup(wrap(Collections.singletonList(false))) - .addEqualityGroup(wrap(getUpperBound(Value.ValueTypeCase.ARRAY_VALUE))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap(Collections.singletonList(false))))) + + // vectors + .addEqualityGroup( + wrap(map("__type__", "__vector__", "value", Collections.singletonList(1.0)))) + .addEqualityGroup( + wrap( + getUpperBound( + TestUtil.wrap( + map("__type__", "__vector__", "value", Collections.singletonList(1.0)))))) // objects .addEqualityGroup(wrap(map("a", "b"))) + .addEqualityGroup(wrap(getUpperBound(TestUtil.wrap(map("a", "b"))))) .testCompare(); } diff --git a/firebase-firestore/src/test/java/com/google/firebase/firestore/remote/RemoteSerializerTest.java b/firebase-firestore/src/test/java/com/google/firebase/firestore/remote/RemoteSerializerTest.java index 30e5715613b..52eec0ac4cd 100644 --- a/firebase-firestore/src/test/java/com/google/firebase/firestore/remote/RemoteSerializerTest.java +++ b/firebase-firestore/src/test/java/com/google/firebase/firestore/remote/RemoteSerializerTest.java @@ -40,6 +40,7 @@ import static org.junit.Assert.assertTrue; import com.google.firebase.firestore.DocumentReference; +import com.google.firebase.firestore.FieldValue; import com.google.firebase.firestore.GeoPoint; import com.google.firebase.firestore.core.ArrayContainsAnyFilter; import com.google.firebase.firestore.core.FieldFilter; @@ -308,6 +309,26 @@ public void testEncodesNestedObjects() { assertRoundTrip(model.get(FieldPath.EMPTY_PATH), proto, Value.ValueTypeCase.MAP_VALUE); } + @Test + public void testEncodesVectorValue() { + Value model = wrap(FieldValue.vector(new double[] {1, 2, 3})); + + ArrayValue.Builder array = + ArrayValue.newBuilder() + .addValues(Value.newBuilder().setDoubleValue(1)) + .addValues(Value.newBuilder().setDoubleValue(2)) + .addValues(Value.newBuilder().setDoubleValue(3)); + + MapValue.Builder obj = + MapValue.newBuilder() + .putFields("__type__", Value.newBuilder().setStringValue("__vector__").build()) + .putFields("value", Value.newBuilder().setArrayValue(array).build()); + + Value proto = Value.newBuilder().setMapValue(obj).build(); + + assertRoundTrip(model, proto, Value.ValueTypeCase.MAP_VALUE); + } + @Test public void testEncodeDeleteMutation() { Mutation mutation = deleteMutation("docs/1");