Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.Preconditions.checkState;

import com.google.api.gax.grpc.GrpcStatusCode;
import com.google.api.gax.rpc.DeadlineExceededException;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.UnavailableException;
import com.google.cloud.spanner.SessionImpl.SessionTransaction;
Expand Down Expand Up @@ -83,7 +84,7 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
boolean foundStats = false;
long updateCount = 0L;
Duration remainingTimeout = timeout;
Stopwatch stopWatch = Stopwatch.createStarted();
Stopwatch stopWatch = createStopwatchStarted();
try {
// Loop to catch AbortedExceptions.
while (true) {
Expand All @@ -107,6 +108,11 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
while (true) {
remainingTimeout =
remainingTimeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS);
if (remainingTimeout.isNegative() || remainingTimeout.isZero()) {
// The total deadline has been exceeded while retrying.
throw new DeadlineExceededException(
null, GrpcStatusCode.of(Code.DEADLINE_EXCEEDED), false);
}
try {
builder.setResumeToken(resumeToken);
ServerStream<PartialResultSet> stream =
Expand Down Expand Up @@ -157,6 +163,10 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
}
}

Stopwatch createStopwatchStarted() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would it be hard to inject the Stopwatch instead? Not a big deal, just wondering if we could use composition here. Disregard if it would be too hard to inject.

return Stopwatch.createStarted();
}

@Override
public void invalidate() {
isValid = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ public GapicSpannerRpc(final SpannerOptions options) {

// Set a keepalive time of 120 seconds to help long running
// commit GRPC calls succeed
.setKeepAliveTime(Duration.ofSeconds(GRPC_KEEPALIVE_SECONDS * 1000))
.setKeepAliveTime(Duration.ofSeconds(GRPC_KEEPALIVE_SECONDS))

// Then check if SpannerOptions provides an InterceptorProvider. Create a default
// SpannerInterceptorProvider if none is provided
Expand Down Expand Up @@ -365,6 +365,7 @@ public GapicSpannerRpc(final SpannerOptions options) {
.setStreamWatchdogProvider(watchdogProvider)
.executeSqlSettings()
.setRetrySettings(partitionedDmlRetrySettings);
pdmlSettings.executeStreamingSqlSettings().setRetrySettings(partitionedDmlRetrySettings);
// The stream watchdog will by default only check for a timeout every 10 seconds, so if the
// timeout is less than 10 seconds, it would be ignored for the first 10 seconds unless we
// also change the StreamWatchdogCheckInterval.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spanner;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.api.gax.grpc.GrpcStatusCode;
import com.google.api.gax.rpc.AbortedException;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.UnavailableException;
import com.google.cloud.spanner.spi.v1.SpannerRpc;
import com.google.common.base.Stopwatch;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.spanner.v1.BeginTransactionRequest;
import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.ExecuteSqlRequest.QueryMode;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.ResultSetStats;
import com.google.spanner.v1.Transaction;
import com.google.spanner.v1.TransactionSelector;
import io.grpc.Status.Code;
import java.util.Collections;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.threeten.bp.Duration;

@SuppressWarnings("unchecked")
@RunWith(JUnit4.class)
public class PartitionedDmlTransactionTest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice tests!


@Mock private SpannerRpc rpc;

@Mock private SessionImpl session;

private final String sessionId = "projects/p/instances/i/databases/d/sessions/s";
private final ByteString txId = ByteString.copyFromUtf8("tx");
private final ByteString resumeToken = ByteString.copyFromUtf8("resume");
private final String sql = "UPDATE FOO SET BAR=1 WHERE TRUE";
private final ExecuteSqlRequest executeRequestWithoutResumeToken =
ExecuteSqlRequest.newBuilder()
.setQueryMode(QueryMode.NORMAL)
.setSession(sessionId)
.setSql(sql)
.setTransaction(TransactionSelector.newBuilder().setId(txId))
.build();
private final ExecuteSqlRequest executeRequestWithResumeToken =
executeRequestWithoutResumeToken.toBuilder().setResumeToken(resumeToken).build();

@Before
public void setup() {
MockitoAnnotations.initMocks(this);
when(session.getName()).thenReturn(sessionId);
when(session.getOptions()).thenReturn(Collections.EMPTY_MAP);
when(rpc.beginTransaction(any(BeginTransactionRequest.class), anyMap()))
.thenReturn(Transaction.newBuilder().setId(txId).build());
}

@Test
public void testExecuteStreamingPartitionedUpdate() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream = mock(ServerStream.class);
when(stream.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateAborted() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new AbortedException(
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
when(stream1.iterator()).thenReturn(iterator);
ServerStream<PartialResultSet> stream2 = mock(ServerStream.class);
when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
any(ExecuteSqlRequest.class), anyMap(), any(Duration.class)))
.thenReturn(stream1, stream2);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc, times(2))
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateUnavailable() {
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new UnavailableException(
"temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true));
when(stream1.iterator()).thenReturn(iterator);
ServerStream<PartialResultSet> stream2 = mock(ServerStream.class);
when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream2);

PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
assertThat(count).isEqualTo(1000L);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class));
}

@Test
public void testExecuteStreamingPartitionedUpdateUnavailableAndThenDeadlineExceeded() {
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new UnavailableException(
"temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true));
when(stream1.iterator()).thenReturn(iterator);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);

PartitionedDMLTransaction tx =
new PartitionedDMLTransaction(session, rpc) {
@Override
Stopwatch createStopwatchStarted() {
Ticker ticker = mock(Ticker.class);
when(ticker.read())
.thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES));
return Stopwatch.createStarted(ticker);
}
};
try {
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
fail("missing expected DEADLINE_EXCEEDED exception");
} catch (SpannerException e) {
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}
}

@Test
public void testExecuteStreamingPartitionedUpdateAbortedAndThenDeadlineExceeded() {
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
Iterator<PartialResultSet> iterator = mock(Iterator.class);
when(iterator.hasNext()).thenReturn(true, true, false);
when(iterator.next())
.thenReturn(p1)
.thenThrow(
new AbortedException(
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
when(stream1.iterator()).thenReturn(iterator);
when(rpc.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
.thenReturn(stream1);

PartitionedDMLTransaction tx =
new PartitionedDMLTransaction(session, rpc) {
@Override
Stopwatch createStopwatchStarted() {
Ticker ticker = mock(Ticker.class);
when(ticker.read())
.thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES));
return Stopwatch.createStarted(ticker);
}
};
try {
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
fail("missing expected DEADLINE_EXCEEDED exception");
} catch (SpannerException e) {
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
verify(rpc)
.executeStreamingPartitionedDml(
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
}
}
}