diff --git a/google-cloud-clients/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Firestore.java b/google-cloud-clients/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Firestore.java index cf1be3eb3826..26ee155ea599 100644 --- a/google-cloud-clients/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Firestore.java +++ b/google-cloud-clients/google-cloud-firestore/src/main/java/com/google/cloud/firestore/Firestore.java @@ -17,6 +17,7 @@ package com.google.cloud.firestore; import com.google.api.core.ApiFuture; +import com.google.api.gax.rpc.ApiStreamObserver; import com.google.cloud.Service; import java.util.List; import javax.annotation.Nonnull; @@ -117,6 +118,20 @@ ApiFuture runTransaction( ApiFuture> getAll( @Nonnull DocumentReference[] documentReferences, @Nullable FieldMask fieldMask); + /** + * Retrieves multiple documents from Firestore while optionally applying a field mask to reduce + * the amount of data transmitted. Returned documents will be out of order. + * + * @param documentReferences Array with Document References to fetch. + * @param fieldMask If not null, specifies the subset of fields to return. + * @param responseObserver The observer to be notified when {@link DocumentSnapshot} details + * arrive. + */ + void getAll( + @Nonnull DocumentReference[] documentReferences, + @Nullable FieldMask fieldMask, + final ApiStreamObserver responseObserver); + /** * Gets a Firestore {@link WriteBatch} instance that can be used to combine multiple writes. * diff --git a/google-cloud-clients/google-cloud-firestore/src/main/java/com/google/cloud/firestore/FirestoreImpl.java b/google-cloud-clients/google-cloud-firestore/src/main/java/com/google/cloud/firestore/FirestoreImpl.java index 77a29901c995..cf9bfdb483c8 100644 --- a/google-cloud-clients/google-cloud-firestore/src/main/java/com/google/cloud/firestore/FirestoreImpl.java +++ b/google-cloud-clients/google-cloud-firestore/src/main/java/com/google/cloud/firestore/FirestoreImpl.java @@ -140,23 +140,30 @@ public Iterable getCollections() { @Override public ApiFuture> getAll( @Nonnull DocumentReference... documentReferences) { - return this.getAll(documentReferences, null, null); + return this.getAll(documentReferences, null, (ByteString) null); } @Nonnull @Override public ApiFuture> getAll( @Nonnull DocumentReference[] documentReferences, @Nullable FieldMask fieldMask) { - return this.getAll(documentReferences, fieldMask, null); + return this.getAll(documentReferences, fieldMask, (ByteString) null); } - /** Internal getAll() method that accepts an optional transaction id. */ - ApiFuture> getAll( - final DocumentReference[] documentReferences, + @Nonnull + @Override + public void getAll( + final @Nonnull DocumentReference[] documentReferences, @Nullable FieldMask fieldMask, - @Nullable ByteString transactionId) { - final SettableApiFuture> futureList = SettableApiFuture.create(); - final Map resultMap = new HashMap<>(); + @Nonnull final ApiStreamObserver apiStreamObserver) { + this.getAll(documentReferences, fieldMask, null, apiStreamObserver); + } + + void getAll( + final @Nonnull DocumentReference[] documentReferences, + @Nullable FieldMask fieldMask, + @Nullable ByteString transactionId, + final ApiStreamObserver apiStreamObserver) { ApiStreamObserver responseObserver = new ApiStreamObserver() { @@ -176,9 +183,6 @@ public void onNext(BatchGetDocumentsResponse response) { switch (response.getResultCase()) { case FOUND: - documentReference = - new DocumentReference( - FirestoreImpl.this, ResourcePath.create(response.getFound().getName())); documentSnapshot = DocumentSnapshot.fromDocument( FirestoreImpl.this, @@ -198,26 +202,19 @@ public void onNext(BatchGetDocumentsResponse response) { default: return; } - - resultMap.put(documentReference, documentSnapshot); + apiStreamObserver.onNext(documentSnapshot); } @Override public void onError(Throwable throwable) { tracer.getCurrentSpan().addAnnotation("Firestore.BatchGet: Error"); - futureList.setException(throwable); + apiStreamObserver.onError(throwable); } @Override public void onCompleted() { tracer.getCurrentSpan().addAnnotation("Firestore.BatchGet: Complete"); - List documentSnapshots = new ArrayList<>(); - - for (DocumentReference documentReference : documentReferences) { - documentSnapshots.add(resultMap.get(documentReference)); - } - - futureList.set(documentSnapshots); + apiStreamObserver.onCompleted(); } }; @@ -244,7 +241,39 @@ public void onCompleted() { "numDocuments", AttributeValue.longAttributeValue(documentReferences.length))); streamRequest(request.build(), responseObserver, firestoreClient.batchGetDocumentsCallable()); + } + /** Internal getAll() method that accepts an optional transaction id. */ + ApiFuture> getAll( + final @Nonnull DocumentReference[] documentReferences, + @Nullable FieldMask fieldMask, + @Nullable ByteString transactionId) { + final SettableApiFuture> futureList = SettableApiFuture.create(); + final Map documentSnapshotMap = new HashMap<>(); + getAll( + documentReferences, + fieldMask, + transactionId, + new ApiStreamObserver() { + @Override + public void onNext(DocumentSnapshot documentSnapshot) { + documentSnapshotMap.put(documentSnapshot.getReference(), documentSnapshot); + } + + @Override + public void onError(Throwable throwable) { + futureList.setException(throwable); + } + + @Override + public void onCompleted() { + List documentSnapshotsList = new ArrayList<>(); + for (DocumentReference documentReference : documentReferences) { + documentSnapshotsList.add(documentSnapshotMap.get(documentReference)); + } + futureList.set(documentSnapshotsList); + } + }); return futureList; } diff --git a/google-cloud-clients/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITSystemTest.java b/google-cloud-clients/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITSystemTest.java index de19f5492da7..eb9e56888efa 100644 --- a/google-cloud-clients/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITSystemTest.java +++ b/google-cloud-clients/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITSystemTest.java @@ -22,6 +22,7 @@ import static java.util.Arrays.asList; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; @@ -29,6 +30,8 @@ import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.rpc.ApiStreamObserver; import com.google.cloud.Timestamp; import com.google.cloud.firestore.CollectionReference; import com.google.cloud.firestore.DocumentChange; @@ -1327,4 +1330,53 @@ private static final class ListenerEvent { this.error = error; } } + + @Test + public void getAllWithObserver() throws Exception { + DocumentReference ref1 = randomColl.document("doc1"); + ref1.set(ALL_SUPPORTED_TYPES_MAP).get(); + + DocumentReference ref2 = randomColl.document("doc2"); + ref2.set(ALL_SUPPORTED_TYPES_MAP).get(); + + DocumentReference ref3 = randomColl.document("doc3"); + + final List documentSnapshots = + Collections.synchronizedList(new ArrayList()); + final DocumentReference[] documentReferences = {ref1, ref2, ref3}; + final SettableApiFuture future = SettableApiFuture.create(); + firestore.getAll( + documentReferences, + FieldMask.of("foo"), + new ApiStreamObserver() { + + @Override + public void onNext(DocumentSnapshot documentSnapshot) { + documentSnapshots.add(documentSnapshot); + } + + @Override + public void onError(Throwable throwable) { + future.setException(throwable); + } + + @Override + public void onCompleted() { + future.set(null); + } + }); + + future.get(); + + assertEquals( + ALL_SUPPORTED_TYPES_OBJECT, documentSnapshots.get(0).toObject(AllSupportedTypes.class)); + assertEquals( + ALL_SUPPORTED_TYPES_OBJECT, documentSnapshots.get(1).toObject(AllSupportedTypes.class)); + assertNotEquals( + ALL_SUPPORTED_TYPES_OBJECT, documentSnapshots.get(2).toObject(AllSupportedTypes.class)); + assertEquals(ref1.getId(), documentSnapshots.get(0).getId()); + assertEquals(ref2.getId(), documentSnapshots.get(1).getId()); + assertEquals(ref3.getId(), documentSnapshots.get(2).getId()); + assertEquals(3, documentSnapshots.size()); + } }