Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@
import com.google.api.core.NanoClock;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.BidiStreamObserver;
import com.google.api.gax.rpc.BidiStreamingCallable;
import com.google.api.gax.rpc.ClientStream;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.spi.v1.FirestoreRpc;
Expand Down Expand Up @@ -201,7 +205,6 @@ public ApiFuture<List<DocumentSnapshot>> getAll(
return this.getAll(documentReferences, fieldMask, (ByteString) null);
}

@Nonnull
@Override
public void getAll(
final @Nonnull DocumentReference[] documentReferences,
Expand All @@ -216,12 +219,15 @@ void getAll(
@Nullable ByteString transactionId,
final ApiStreamObserver<DocumentSnapshot> apiStreamObserver) {

ApiStreamObserver<BatchGetDocumentsResponse> responseObserver =
new ApiStreamObserver<BatchGetDocumentsResponse>() {
ResponseObserver<BatchGetDocumentsResponse> responseObserver =
new ResponseObserver<BatchGetDocumentsResponse>() {
int numResponses;

@Override
public void onNext(BatchGetDocumentsResponse response) {
public void onStart(StreamController streamController) {}

@Override
public void onResponse(BatchGetDocumentsResponse response) {
DocumentReference documentReference;
DocumentSnapshot documentSnapshot;

Expand Down Expand Up @@ -270,7 +276,7 @@ public void onError(Throwable throwable) {
}

@Override
public void onCompleted() {
public void onComplete() {
tracer
.getCurrentSpan()
.addAnnotation(TraceUtil.SPAN_NAME_BATCHGETDOCUMENTS + ": Complete");
Expand Down Expand Up @@ -433,19 +439,19 @@ public <RequestT, ResponseT> ApiFuture<ResponseT> sendRequest(
@Override
public <RequestT, ResponseT> void streamRequest(
RequestT requestT,
ApiStreamObserver<ResponseT> responseObserverT,
ResponseObserver<ResponseT> responseObserverT,
ServerStreamingCallable<RequestT, ResponseT> callable) {
Preconditions.checkState(!closed, "Firestore client has already been closed");
callable.serverStreamingCall(requestT, responseObserverT);
callable.call(requestT, responseObserverT);
}

/** Request funnel for all bidirectional streaming requests. */
@Override
public <RequestT, ResponseT> ApiStreamObserver<RequestT> streamRequest(
ApiStreamObserver<ResponseT> responseObserverT,
public <RequestT, ResponseT> ClientStream<RequestT> streamRequest(
BidiStreamObserver<RequestT, ResponseT> responseObserverT,
BidiStreamingCallable<RequestT, ResponseT> callable) {
Preconditions.checkState(!closed, "Firestore client has already been closed");
return callable.bidiStreamingCall(responseObserverT);
return callable.splitCall(responseObserverT);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import com.google.api.core.ApiFuture;
import com.google.api.core.InternalApi;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.BidiStreamObserver;
import com.google.api.gax.rpc.BidiStreamingCallable;
import com.google.api.gax.rpc.ClientStream;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.cloud.firestore.spi.v1.FirestoreRpc;
Expand All @@ -48,10 +50,10 @@ <RequestT, ResponseT> ApiFuture<ResponseT> sendRequest(

<RequestT, ResponseT> void streamRequest(
RequestT requestT,
ApiStreamObserver<ResponseT> responseObserverT,
ResponseObserver<ResponseT> responseObserverT,
ServerStreamingCallable<RequestT, ResponseT> callable);

<RequestT, ResponseT> ApiStreamObserver<RequestT> streamRequest(
ApiStreamObserver<ResponseT> responseObserverT,
<RequestT, ResponseT> ClientStream<RequestT> streamRequest(
BidiStreamObserver<RequestT, ResponseT> responseObserverT,
Comment on lines +53 to +57
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the main changes, the rest is plumbing and cleanup.

BidiStreamingCallable<RequestT, ResponseT> callable);
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
import com.google.api.core.InternalExtensionOnly;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.StatusCode;
import com.google.api.gax.rpc.StreamController;
import com.google.auto.value.AutoValue;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.Query.QueryOptions.Builder;
Expand Down Expand Up @@ -1502,14 +1504,17 @@ private void internalStream(

final AtomicReference<QueryDocumentSnapshot> lastReceivedDocument = new AtomicReference<>();

ApiStreamObserver<RunQueryResponse> observer =
new ApiStreamObserver<RunQueryResponse>() {
ResponseObserver<RunQueryResponse> observer =
new ResponseObserver<RunQueryResponse>() {
Timestamp readTime;
boolean firstResponse;
int numDocuments;

@Override
public void onNext(RunQueryResponse response) {
public void onStart(StreamController streamController) {}

@Override
public void onResponse(RunQueryResponse response) {
if (!firstResponse) {
firstResponse = true;
Tracing.getTracer().getCurrentSpan().addAnnotation("Firestore.Query: First response");
Expand Down Expand Up @@ -1557,7 +1562,7 @@ public void onError(Throwable throwable) {
}

@Override
public void onCompleted() {
public void onComplete() {
Tracing.getTracer()
.getCurrentSpan()
.addAnnotation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import com.google.api.gax.retrying.ExponentialRetryAlgorithm;
import com.google.api.gax.retrying.TimedAttemptSettings;
import com.google.api.gax.rpc.ApiException;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.BidiStreamObserver;
import com.google.api.gax.rpc.ClientStream;
import com.google.api.gax.rpc.StreamController;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.DocumentChange.Type;
import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -57,7 +59,7 @@
* It synchronizes on its own instance so it is advisable not to use this class for external
* synchronization.
*/
class Watch implements ApiStreamObserver<ListenResponse> {
class Watch implements BidiStreamObserver<ListenRequest, ListenResponse> {
/**
* Target ID used by watch. Watch uses a fixed target id since we only support one target per
* stream. The actual target ID we use is arbitrary.
Expand All @@ -71,7 +73,7 @@ class Watch implements ApiStreamObserver<ListenResponse> {
private final ExponentialRetryAlgorithm backoff;
private final Target target;
private TimedAttemptSettings nextAttempt;
private ApiStreamObserver<ListenRequest> stream;
private ClientStream<ListenRequest> stream;

/** The sorted tree of DocumentSnapshots as sent in the last snapshot. */
private DocumentSet documentSet;
Expand Down Expand Up @@ -167,7 +169,13 @@ static Watch forQuery(Query query) {
}

@Override
public synchronized void onNext(ListenResponse listenResponse) {
public void onStart(StreamController streamController) {}

@Override
public void onReady(ClientStream<ListenRequest> clientStream) {}

@Override
public synchronized void onResponse(ListenResponse listenResponse) {
switch (listenResponse.getResponseTypeCase()) {
case TARGET_CHANGE:
TargetChange change = listenResponse.getTargetChange();
Expand Down Expand Up @@ -258,7 +266,7 @@ public synchronized void onError(Throwable throwable) {
}

@Override
public synchronized void onCompleted() {
public synchronized void onComplete() {
maybeReopenStream(new StatusException(Status.fromCode(Code.UNKNOWN)));
}

Expand Down Expand Up @@ -289,7 +297,7 @@ ListenerRegistration runWatch(
.execute(
() -> {
synchronized (Watch.this) {
stream.onCompleted();
stream.closeSend();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now closes the stream instead of invoking a no-op method.

stream = null;
}
});
Expand Down Expand Up @@ -321,7 +329,7 @@ private void resetDocs() {
/** Closes the stream and calls onError() if the stream is still active. */
private void closeStream(final Throwable throwable) {
if (stream != null) {
stream.onCompleted();
stream.closeSend();
stream = null;
}

Expand Down Expand Up @@ -363,7 +371,7 @@ private void maybeReopenStream(Throwable throwable) {
/** Helper to restart the outgoing stream to the backend. */
private void resetStream() {
if (stream != null) {
stream.onCompleted();
stream.closeSend();
stream = null;
}

Expand Down Expand Up @@ -399,7 +407,7 @@ private void initStream() {
request.getAddTargetBuilder().setResumeToken(resumeToken);
}

stream.onNext(request.build());
stream.send(request.build());
}
} catch (Throwable throwable) {
onError(throwable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.runners.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;

@RunWith(MockitoJUnitRunner.class)
public class BulkWriterTest {
Expand Down Expand Up @@ -1119,11 +1118,10 @@ public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit)
};

doAnswer(
(Answer<ApiFuture<GeneratedMessageV3>>)
mock -> {
retryAttempts[0]++;
return RETRYABLE_FAILED_FUTURE;
})
mock -> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this test the new changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's hard to test these changes. We could mock the backend stream, but then we are essentially only testing the behavior of the mock. If you know of a pre-existing implementation/fake of a GRPC stream that we can use to test this behavior, then we can add a test. A homegrown implementation that validates that our code follows our ow assumptions will not provide us with much meaningful test coverage.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test queryWatchShutsDownStreamOnPermissionDenied() that re-uses some of the existing functionality.

retryAttempts[0]++;
return RETRYABLE_FAILED_FUTURE;
})
.when(firestoreMock)
.sendRequest(
batchWriteCapture.capture(),
Expand Down Expand Up @@ -1170,11 +1168,10 @@ public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit)
};

doAnswer(
(Answer<ApiFuture<GeneratedMessageV3>>)
mock -> {
retryAttempts[0]++;
return RESOURCE_EXHAUSTED_FAILED_FUTURE;
})
mock -> {
retryAttempts[0]++;
return RESOURCE_EXHAUSTED_FAILED_FUTURE;
})
.when(firestoreMock)
.sendRequest(
batchWriteCapture.capture(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@

import com.google.api.core.ApiFuture;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.BidiStreamObserver;
import com.google.api.gax.rpc.BidiStreamingCallable;
import com.google.api.gax.rpc.ClientStream;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.cloud.Timestamp;
Expand Down Expand Up @@ -63,7 +65,6 @@
import com.google.firestore.v1.ListenRequest;
import com.google.firestore.v1.ListenResponse;
import com.google.firestore.v1.RunQueryRequest;
import com.google.protobuf.AbstractMessage;
import com.google.protobuf.Message;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
Expand All @@ -90,7 +91,6 @@
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.stubbing.Answer;

@RunWith(AllTests.class)
public class ConformanceTest {
Expand Down Expand Up @@ -251,7 +251,7 @@ final CollectionReference collection(final String absolutePath) {
private static final class ConformanceGetTestRunner extends BaseConformanceTestRunner<GetTest> {

@Captor private ArgumentCaptor<BatchGetDocumentsRequest> getAllCapture;
@Captor private ArgumentCaptor<ApiStreamObserver<AbstractMessage>> streamObserverCapture;
@Captor private ArgumentCaptor<ResponseObserver<Message>> streamObserverCapture;

private ConformanceGetTestRunner(String description, GetTest testParameters) {
super(description, testParameters);
Expand Down Expand Up @@ -480,7 +480,7 @@ private static final class ConformanceQueryTestRunner
extends BaseConformanceTestRunner<TestDefinition.QueryTest> {

@Captor private ArgumentCaptor<RunQueryRequest> runQueryCapture;
@Captor private ArgumentCaptor<ApiStreamObserver<AbstractMessage>> streamObserverCapture;
@Captor private ArgumentCaptor<ResponseObserver<Message>> streamObserverCapture;

private ConformanceQueryTestRunner(
String description, TestDefinition.QueryTest testParameters) {
Expand Down Expand Up @@ -624,8 +624,8 @@ private DocumentSnapshot convertDocument(DocSnapshot snapshot) {
private static final class ConformanceListenTestRunner
extends BaseConformanceTestRunner<ListenTest> {

@Captor private ArgumentCaptor<ApiStreamObserver<AbstractMessage>> streamObserverCapture;
@Mock private ApiStreamObserver<ListenRequest> noOpRequestObserver;
@Captor private ArgumentCaptor<BidiStreamObserver<Message, Message>> streamObserverCapture;
@Mock private ClientStream<ListenRequest> noOpRequestObserver;

private final Query watchQuery;

Expand All @@ -641,11 +641,10 @@ public void runTest() throws Throwable {
final SettableApiFuture<Void> testCaseFinished = SettableApiFuture.create();

doAnswer(
(Answer<ApiStreamObserver<ListenRequest>>)
invocationOnMock -> {
testCaseStarted.set(null);
return noOpRequestObserver;
})
invocationOnMock -> {
testCaseStarted.set(null);
return noOpRequestObserver;
})
.when(firestore)
.streamRequest(
streamObserverCapture.capture(), Matchers.any(BidiStreamingCallable.class));
Expand Down Expand Up @@ -680,7 +679,7 @@ public void runTest() throws Throwable {
testCaseStarted.get();

for (ListenResponse response : testParameters.getResponsesList()) {
streamObserverCapture.getValue().onNext(response);
streamObserverCapture.getValue().onResponse(response);
}

testCaseFinished.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;

import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.cloud.Timestamp;
Expand Down Expand Up @@ -111,7 +111,7 @@ public class DocumentReferenceTest {

@Captor private ArgumentCaptor<BatchGetDocumentsRequest> getAllCapture;

@Captor private ArgumentCaptor<ApiStreamObserver> streamObserverCapture;
@Captor private ArgumentCaptor<ResponseObserver<CommitResponse>> streamObserverCapture;

private DocumentReference documentReference;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;

import com.google.api.gax.rpc.ApiStreamObserver;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.cloud.firestore.spi.v1.FirestoreRpc;
import com.google.firestore.v1.BatchGetDocumentsRequest;
import com.google.firestore.v1.CommitRequest;
import com.google.firestore.v1.CommitResponse;
import com.google.firestore.v1.ListCollectionIdsRequest;
import com.google.protobuf.Message;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -64,9 +64,7 @@ public class FirestoreTest {

@Captor private ArgumentCaptor<BatchGetDocumentsRequest> getAllCapture;

@Captor private ArgumentCaptor<ListCollectionIdsRequest> listCollectionIdsCapture;

@Captor private ArgumentCaptor<ApiStreamObserver> streamObserverCapture;
@Captor private ArgumentCaptor<ResponseObserver<Message>> streamObserverCapture;

@Captor private ArgumentCaptor<CommitRequest> commitCapture;

Expand Down
Loading