Skip to content

Commit 5789e4b

Browse files
authored
feat: Add retry logic to COUNT queries (#1062)
1 parent 03de01a commit 5789e4b

File tree

3 files changed

+219
-43
lines changed

3 files changed

+219
-43
lines changed

google-cloud-firestore/src/main/java/com/google/cloud/firestore/AggregateQuery.java

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,17 @@
2121
import com.google.api.core.SettableApiFuture;
2222
import com.google.api.gax.rpc.ResponseObserver;
2323
import com.google.api.gax.rpc.ServerStreamingCallable;
24+
import com.google.api.gax.rpc.StatusCode;
2425
import com.google.api.gax.rpc.StreamController;
2526
import com.google.cloud.Timestamp;
27+
import com.google.cloud.firestore.v1.FirestoreSettings;
2628
import com.google.firestore.v1.RunAggregationQueryRequest;
2729
import com.google.firestore.v1.RunAggregationQueryResponse;
2830
import com.google.firestore.v1.RunQueryRequest;
2931
import com.google.firestore.v1.StructuredAggregationQuery;
3032
import com.google.firestore.v1.Value;
3133
import com.google.protobuf.ByteString;
34+
import java.util.Set;
3235
import java.util.concurrent.atomic.AtomicBoolean;
3336
import javax.annotation.Nonnull;
3437
import javax.annotation.Nullable;
@@ -68,25 +71,68 @@ public ApiFuture<AggregateQuerySnapshot> get() {
6871

6972
@Nonnull
7073
ApiFuture<AggregateQuerySnapshot> get(@Nullable final ByteString transactionId) {
71-
RunAggregationQueryRequest request = toProto(transactionId);
72-
AggregateQueryResponseObserver responseObserver = new AggregateQueryResponseObserver();
74+
AggregateQueryResponseDeliverer responseDeliverer =
75+
new AggregateQueryResponseDeliverer(
76+
transactionId, /* startTimeNanos= */ query.rpcContext.getClock().nanoTime());
77+
runQuery(responseDeliverer);
78+
return responseDeliverer.getFuture();
79+
}
80+
81+
private void runQuery(AggregateQueryResponseDeliverer responseDeliverer) {
82+
RunAggregationQueryRequest request = toProto(responseDeliverer.getTransactionId());
83+
AggregateQueryResponseObserver responseObserver =
84+
new AggregateQueryResponseObserver(responseDeliverer);
7385
ServerStreamingCallable<RunAggregationQueryRequest, RunAggregationQueryResponse> callable =
7486
query.rpcContext.getClient().runAggregationQueryCallable();
75-
7687
query.rpcContext.streamRequest(request, responseObserver, callable);
88+
}
89+
90+
private final class AggregateQueryResponseDeliverer {
91+
92+
@Nullable private final ByteString transactionId;
93+
private final long startTimeNanos;
94+
private final SettableApiFuture<AggregateQuerySnapshot> future = SettableApiFuture.create();
95+
private final AtomicBoolean isFutureCompleted = new AtomicBoolean(false);
96+
97+
AggregateQueryResponseDeliverer(@Nullable ByteString transactionId, long startTimeNanos) {
98+
this.transactionId = transactionId;
99+
this.startTimeNanos = startTimeNanos;
100+
}
101+
102+
ApiFuture<AggregateQuerySnapshot> getFuture() {
103+
return future;
104+
}
105+
106+
@Nullable
107+
ByteString getTransactionId() {
108+
return transactionId;
109+
}
110+
111+
long getStartTimeNanos() {
112+
return startTimeNanos;
113+
}
114+
115+
void deliverResult(long count, Timestamp readTime) {
116+
if (isFutureCompleted.compareAndSet(false, true)) {
117+
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));
118+
}
119+
}
77120

78-
return responseObserver.getFuture();
121+
void deliverError(Throwable throwable) {
122+
if (isFutureCompleted.compareAndSet(false, true)) {
123+
future.setException(throwable);
124+
}
125+
}
79126
}
80127

81128
private final class AggregateQueryResponseObserver
82129
implements ResponseObserver<RunAggregationQueryResponse> {
83130

84-
private final SettableApiFuture<AggregateQuerySnapshot> future = SettableApiFuture.create();
85-
private final AtomicBoolean isFutureNotified = new AtomicBoolean(false);
131+
private final AggregateQueryResponseDeliverer responseDeliverer;
86132
private StreamController streamController;
87133

88-
SettableApiFuture<AggregateQuerySnapshot> getFuture() {
89-
return future;
134+
AggregateQueryResponseObserver(AggregateQueryResponseDeliverer responseDeliverer) {
135+
this.responseDeliverer = responseDeliverer;
90136
}
91137

92138
@Override
@@ -96,14 +142,10 @@ public void onStart(StreamController streamController) {
96142

97143
@Override
98144
public void onResponse(RunAggregationQueryResponse response) {
99-
// Ignore subsequent response messages. The RunAggregationQuery RPC returns a stream of
100-
// responses (rather than just a single response); however, only the first response of the
101-
// stream is actually used. Any more responses are technically errors, but since the Future
102-
// will have already been notified, we just drop any unexpected responses.
103-
if (!isFutureNotified.compareAndSet(false, true)) {
104-
return;
105-
}
145+
// Close the stream to avoid it dangling, since we're not expecting any more responses.
146+
streamController.cancel();
106147

148+
// Extract the count and read time from the RunAggregationQueryResponse.
107149
Timestamp readTime = Timestamp.fromProto(response.getReadTime());
108150
Value value = response.getResult().getAggregateFieldsMap().get(ALIAS_COUNT);
109151
if (value == null) {
@@ -118,19 +160,30 @@ public void onResponse(RunAggregationQueryResponse response) {
118160
}
119161
long count = value.getIntegerValue();
120162

121-
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));
122-
123-
// Close the stream to avoid it dangling, since we're not expecting any more responses.
124-
streamController.cancel();
163+
// Deliver the result; even though the `RunAggregationQuery` RPC is a "streaming" RPC, meaning
164+
// that `onResponse()` can be called multiple times, it _should_ only be called once for count
165+
// queries. But even if it is called more than once, `responseDeliverer` will drop superfluous
166+
// results.
167+
responseDeliverer.deliverResult(count, readTime);
125168
}
126169

127170
@Override
128171
public void onError(Throwable throwable) {
129-
if (!isFutureNotified.compareAndSet(false, true)) {
130-
return;
172+
if (shouldRetry(throwable)) {
173+
runQuery(responseDeliverer);
174+
} else {
175+
responseDeliverer.deliverError(throwable);
131176
}
177+
}
132178

133-
future.setException(throwable);
179+
private boolean shouldRetry(Throwable throwable) {
180+
Set<StatusCode.Code> retryableCodes =
181+
FirestoreSettings.newBuilder().runAggregationQuerySettings().getRetryableCodes();
182+
return query.shouldRetryQuery(
183+
throwable,
184+
responseDeliverer.getTransactionId(),
185+
responseDeliverer.getStartTimeNanos(),
186+
retryableCodes);
134187
}
135188

136189
@Override

google-cloud-firestore/src/main/java/com/google/cloud/firestore/Query.java

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,27 +1672,15 @@ public void onComplete() {
16721672
}
16731673

16741674
boolean shouldRetry(DocumentSnapshot lastDocument, Throwable t) {
1675-
if (transactionId != null) {
1676-
// Transactional queries are retried via the transaction runner.
1677-
return false;
1678-
}
1679-
16801675
if (lastDocument == null) {
16811676
// Only retry if we have received a single result. Retries for RPCs with initial
16821677
// failure are handled by Google Gax, which also implements backoff.
16831678
return false;
16841679
}
16851680

1686-
if (!isRetryableError(t)) {
1687-
return false;
1688-
}
1689-
1690-
if (rpcContext.getTotalRequestTimeout().isZero()) {
1691-
return true;
1692-
}
1693-
1694-
Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos);
1695-
return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0;
1681+
Set<StatusCode.Code> retryableCodes =
1682+
FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
1683+
return shouldRetryQuery(t, transactionId, startTimeNanos, retryableCodes);
16961684
}
16971685
};
16981686

@@ -1831,21 +1819,42 @@ private <T> ImmutableList<T> append(ImmutableList<T> existingList, T newElement)
18311819
}
18321820

18331821
/** Verifies whether the given exception is retryable based on the RunQuery configuration. */
1834-
private boolean isRetryableError(Throwable throwable) {
1822+
private boolean isRetryableError(Throwable throwable, Set<StatusCode.Code> retryableCodes) {
18351823
if (!(throwable instanceof FirestoreException)) {
18361824
return false;
18371825
}
1838-
Set<StatusCode.Code> codes =
1839-
FirestoreSettings.newBuilder().runQuerySettings().getRetryableCodes();
18401826
Status status = ((FirestoreException) throwable).getStatus();
1841-
for (StatusCode.Code code : codes) {
1827+
for (StatusCode.Code code : retryableCodes) {
18421828
if (code.equals(StatusCode.Code.valueOf(status.getCode().name()))) {
18431829
return true;
18441830
}
18451831
}
18461832
return false;
18471833
}
18481834

1835+
/** Returns whether a query that failed in the given scenario should be retried. */
1836+
boolean shouldRetryQuery(
1837+
Throwable throwable,
1838+
@Nullable ByteString transactionId,
1839+
long startTimeNanos,
1840+
Set<StatusCode.Code> retryableCodes) {
1841+
if (transactionId != null) {
1842+
// Transactional queries are retried via the transaction runner.
1843+
return false;
1844+
}
1845+
1846+
if (!isRetryableError(throwable, retryableCodes)) {
1847+
return false;
1848+
}
1849+
1850+
if (rpcContext.getTotalRequestTimeout().isZero()) {
1851+
return true;
1852+
}
1853+
1854+
Duration duration = Duration.ofNanos(rpcContext.getClock().nanoTime() - startTimeNanos);
1855+
return duration.compareTo(rpcContext.getTotalRequestTimeout()) < 0;
1856+
}
1857+
18491858
/**
18501859
* Returns a query that counts the documents in the result set of this query.
18511860
*

google-cloud-firestore/src/test/java/com/google/cloud/firestore/QueryCountTest.java

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
import static org.junit.Assert.assertThrows;
3131
import static org.mockito.Mockito.doAnswer;
3232
import static org.mockito.Mockito.doReturn;
33+
import static org.mockito.Mockito.mock;
3334

35+
import com.google.api.core.ApiClock;
3436
import com.google.api.core.ApiFuture;
3537
import com.google.api.gax.rpc.ResponseObserver;
3638
import com.google.api.gax.rpc.ServerStreamingCallable;
@@ -39,14 +41,15 @@
3941
import com.google.firestore.v1.RunAggregationQueryRequest;
4042
import com.google.firestore.v1.RunAggregationQueryResponse;
4143
import com.google.firestore.v1.StructuredQuery;
44+
import io.grpc.Status;
4245
import java.util.concurrent.ExecutionException;
46+
import java.util.concurrent.TimeUnit;
4347
import org.junit.Before;
4448
import org.junit.Test;
4549
import org.junit.runner.RunWith;
4650
import org.mockito.ArgumentCaptor;
4751
import org.mockito.Captor;
4852
import org.mockito.Matchers;
49-
import org.mockito.Mockito;
5053
import org.mockito.Spy;
5154
import org.mockito.runners.MockitoJUnitRunner;
5255
import org.threeten.bp.Duration;
@@ -58,7 +61,7 @@ public class QueryCountTest {
5861
private final FirestoreImpl firestoreMock =
5962
new FirestoreImpl(
6063
FirestoreOptions.newBuilder().setProjectId("test-project").build(),
61-
Mockito.mock(FirestoreRpc.class));
64+
mock(FirestoreRpc.class));
6265

6366
@Captor private ArgumentCaptor<RunAggregationQueryRequest> runAggregationQuery;
6467

@@ -211,6 +214,117 @@ public void aggregateQuerySnapshotGetQueryShouldReturnCorrectValue() throws Exce
211214

212215
AggregateQuery aggregateQuery = query.count();
213216
AggregateQuerySnapshot snapshot = aggregateQuery.get().get();
217+
214218
assertThat(snapshot.getQuery()).isSameInstanceAs(aggregateQuery);
215219
}
220+
221+
@Test
222+
public void shouldNotRetryIfExceptionIsNotFirestoreException() {
223+
doAnswer(aggregationQueryResponse(new NotFirestoreException()))
224+
.doAnswer(aggregationQueryResponse())
225+
.when(firestoreMock)
226+
.streamRequest(
227+
runAggregationQuery.capture(),
228+
streamObserverCapture.capture(),
229+
Matchers.<ServerStreamingCallable>any());
230+
231+
ApiFuture<AggregateQuerySnapshot> future = query.count().get();
232+
233+
assertThrows(ExecutionException.class, future::get);
234+
}
235+
236+
@Test
237+
public void shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatus() throws Exception {
238+
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
239+
.doAnswer(aggregationQueryResponse(42))
240+
.when(firestoreMock)
241+
.streamRequest(
242+
runAggregationQuery.capture(),
243+
streamObserverCapture.capture(),
244+
Matchers.<ServerStreamingCallable>any());
245+
246+
ApiFuture<AggregateQuerySnapshot> future = query.count().get();
247+
AggregateQuerySnapshot snapshot = future.get();
248+
249+
assertThat(snapshot.getCount()).isEqualTo(42);
250+
}
251+
252+
@Test
253+
public void shouldNotRetryIfExceptionIsFirestoreExceptionWithNonRetryableStatus() {
254+
doReturn(Duration.ZERO).when(firestoreMock).getTotalRequestTimeout();
255+
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INVALID_ARGUMENT)))
256+
.doAnswer(aggregationQueryResponse())
257+
.when(firestoreMock)
258+
.streamRequest(
259+
runAggregationQuery.capture(),
260+
streamObserverCapture.capture(),
261+
Matchers.<ServerStreamingCallable>any());
262+
263+
ApiFuture<AggregateQuerySnapshot> future = query.count().get();
264+
265+
assertThrows(ExecutionException.class, future::get);
266+
}
267+
268+
@Test
269+
public void
270+
shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatusWithInfiniteTimeoutWindow()
271+
throws Exception {
272+
doReturn(Duration.ZERO).when(firestoreMock).getTotalRequestTimeout();
273+
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
274+
.doAnswer(aggregationQueryResponse(42))
275+
.when(firestoreMock)
276+
.streamRequest(
277+
runAggregationQuery.capture(),
278+
streamObserverCapture.capture(),
279+
Matchers.<ServerStreamingCallable>any());
280+
281+
ApiFuture<AggregateQuerySnapshot> future = query.count().get();
282+
AggregateQuerySnapshot snapshot = future.get();
283+
284+
assertThat(snapshot.getCount()).isEqualTo(42);
285+
}
286+
287+
@Test
288+
public void shouldRetryIfExceptionIsFirestoreExceptionWithRetryableStatusWithinTimeoutWindow()
289+
throws Exception {
290+
doReturn(Duration.ofDays(999)).when(firestoreMock).getTotalRequestTimeout();
291+
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
292+
.doAnswer(aggregationQueryResponse(42))
293+
.when(firestoreMock)
294+
.streamRequest(
295+
runAggregationQuery.capture(),
296+
streamObserverCapture.capture(),
297+
Matchers.<ServerStreamingCallable>any());
298+
299+
ApiFuture<AggregateQuerySnapshot> future = query.count().get();
300+
AggregateQuerySnapshot snapshot = future.get();
301+
302+
assertThat(snapshot.getCount()).isEqualTo(42);
303+
}
304+
305+
@Test
306+
public void
307+
shouldNotRetryIfExceptionIsFirestoreExceptionWithRetryableStatusBeyondTimeoutWindow() {
308+
ApiClock clockMock = mock(ApiClock.class);
309+
doReturn(clockMock).when(firestoreMock).getClock();
310+
doReturn(TimeUnit.SECONDS.toNanos(10))
311+
.doReturn(TimeUnit.SECONDS.toNanos(20))
312+
.doReturn(TimeUnit.SECONDS.toNanos(30))
313+
.when(clockMock)
314+
.nanoTime();
315+
doReturn(Duration.ofSeconds(5)).when(firestoreMock).getTotalRequestTimeout();
316+
doAnswer(aggregationQueryResponse(new FirestoreException("reason", Status.INTERNAL)))
317+
.doAnswer(aggregationQueryResponse(42))
318+
.when(firestoreMock)
319+
.streamRequest(
320+
runAggregationQuery.capture(),
321+
streamObserverCapture.capture(),
322+
Matchers.<ServerStreamingCallable>any());
323+
324+
ApiFuture<AggregateQuerySnapshot> future = query.count().get();
325+
326+
assertThrows(ExecutionException.class, future::get);
327+
}
328+
329+
private static final class NotFirestoreException extends Exception {}
216330
}

0 commit comments

Comments
 (0)