Skip to content

Commit 08bc846

Browse files
authored
fix: split connection pool based on credential (#2388)
1 parent 81a7d73 commit 08bc846

File tree

3 files changed

+133
-4
lines changed

3 files changed

+133
-4
lines changed

google-cloud-bigquerystorage/pom.xml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,11 @@
151151
<groupId>org.json</groupId>
152152
<artifactId>json</artifactId>
153153
</dependency>
154-
154+
<dependency>
155+
<groupId>com.google.auth</groupId>
156+
<artifactId>google-auth-library-credentials</artifactId>
157+
<version>1.22.0</version>
158+
</dependency>
155159

156160
<!-- Test dependencies -->
157161
<dependency>
@@ -200,6 +204,12 @@
200204
<artifactId>google-cloud-bigquery</artifactId>
201205
<scope>test</scope>
202206
</dependency>
207+
<dependency>
208+
<groupId>com.google.auth</groupId>
209+
<artifactId>google-auth-library-oauth2-http</artifactId>
210+
<version>1.22.0</version>
211+
<scope>test</scope>
212+
</dependency>
203213
<dependency>
204214
<groupId>com.google.code.findbugs</groupId>
205215
<artifactId>jsr305</artifactId>

google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.google.api.gax.core.ExecutorProvider;
2222
import com.google.api.gax.retrying.RetrySettings;
2323
import com.google.api.gax.rpc.TransportChannelProvider;
24+
import com.google.auth.Credentials;
2425
import com.google.auto.value.AutoOneOf;
2526
import com.google.auto.value.AutoValue;
2627
import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.MissingValueInterpretation;
@@ -46,6 +47,7 @@
4647
import java.util.logging.Logger;
4748
import java.util.regex.Matcher;
4849
import java.util.regex.Pattern;
50+
import javax.annotation.Nullable;
4951

5052
/**
5153
* A BigQuery Stream Writer that can be used to write data into BigQuery Table.
@@ -134,8 +136,11 @@ public static long getApiMaxRequestBytes() {
134136
abstract static class ConnectionPoolKey {
135137
abstract String location();
136138

137-
public static ConnectionPoolKey create(String location) {
138-
return new AutoValue_StreamWriter_ConnectionPoolKey(location);
139+
abstract int credentialsHashcode();
140+
141+
public static ConnectionPoolKey create(String location, @Nullable Credentials credentials) {
142+
return new AutoValue_StreamWriter_ConnectionPoolKey(
143+
location, credentials != null ? credentials.hashCode() : 0);
139144
}
140145
}
141146

@@ -273,14 +278,17 @@ private StreamWriter(Builder builder) throws IOException {
273278
}
274279
}
275280
this.location = location;
281+
CredentialsProvider credentialsProvider = client.getSettings().getCredentialsProvider();
276282
// Assume the connection in the same pool share the same client and trace id.
277283
// The first StreamWriter for a new stub will create the pool for the other
278284
// streams in the same region, meaning the per StreamWriter settings are no
279285
// longer working unless all streams share the same set of settings
280286
this.singleConnectionOrConnectionPool =
281287
SingleConnectionOrConnectionPool.ofConnectionPool(
282288
connectionPoolMap.computeIfAbsent(
283-
ConnectionPoolKey.create(location),
289+
ConnectionPoolKey.create(
290+
location,
291+
credentialsProvider != null ? credentialsProvider.getCredentials() : null),
284292
(key) -> {
285293
return new ConnectionWorkerPool(
286294
builder.maxInflightRequest,
@@ -581,6 +589,11 @@ ConnectionWorkerPool getTestOnlyConnectionWorkerPool() {
581589
return connectionWorkerPool;
582590
}
583591

592+
@VisibleForTesting
593+
Map<ConnectionPoolKey, ConnectionWorkerPool> getTestOnlyConnectionPoolMap() {
594+
return connectionPoolMap;
595+
}
596+
584597
// A method to clear the static connectio pool to avoid making pool visible to other tests.
585598
@VisibleForTesting
586599
static void clearConnectionPool() {

google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import com.google.api.core.ApiFutureCallback;
2727
import com.google.api.core.ApiFutures;
2828
import com.google.api.gax.batching.FlowController;
29+
import com.google.api.gax.core.CredentialsProvider;
30+
import com.google.api.gax.core.FixedCredentialsProvider;
2931
import com.google.api.gax.core.GoogleCredentialsProvider;
3032
import com.google.api.gax.core.InstantiatingExecutorProvider;
3133
import com.google.api.gax.core.NoCredentialsProvider;
@@ -38,6 +40,7 @@
3840
import com.google.api.gax.rpc.InvalidArgumentException;
3941
import com.google.api.gax.rpc.StatusCode.Code;
4042
import com.google.api.gax.rpc.UnknownException;
43+
import com.google.auth.oauth2.UserCredentials;
4144
import com.google.cloud.bigquery.storage.test.Test.FooType;
4245
import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.MissingValueInterpretation;
4346
import com.google.cloud.bigquery.storage.v1.ConnectionWorkerPool.Settings;
@@ -924,6 +927,109 @@ public void testProtoSchemaPiping_multiplexingCase() throws Exception {
924927
writer2.close();
925928
}
926929

930+
@Test
931+
public void testFixedCredentialProvider_nullProvider() throws Exception {
932+
// Use the shared connection mode.
933+
ConnectionWorkerPool.setOptions(
934+
Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build());
935+
ProtoSchema schema1 = createProtoSchema("Schema1");
936+
ProtoSchema schema2 = createProtoSchema("Schema2");
937+
CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(null);
938+
CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(null);
939+
StreamWriter writer1 =
940+
StreamWriter.newBuilder(TEST_STREAM_1, client)
941+
.setWriterSchema(schema1)
942+
.setLocation("US")
943+
.setEnableConnectionPool(true)
944+
.setMaxInflightRequests(1)
945+
.setCredentialsProvider(credentialsProvider1)
946+
.build();
947+
StreamWriter writer2 =
948+
StreamWriter.newBuilder(TEST_STREAM_2, client)
949+
.setWriterSchema(schema2)
950+
.setMaxInflightRequests(1)
951+
.setEnableConnectionPool(true)
952+
.setCredentialsProvider(credentialsProvider2)
953+
.setLocation("US")
954+
.build();
955+
// Null credential provided belong to the same connection pool.
956+
assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 1);
957+
}
958+
959+
@Test
960+
public void testFixedCredentialProvider_twoCredentialsSplitPool() throws Exception {
961+
// Use the shared connection mode.
962+
ConnectionWorkerPool.setOptions(
963+
Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build());
964+
ProtoSchema schema1 = createProtoSchema("Schema1");
965+
ProtoSchema schema2 = createProtoSchema("Schema2");
966+
UserCredentials userCredentials1 =
967+
UserCredentials.newBuilder()
968+
.setClientId("CLIENT_ID_1")
969+
.setClientSecret("CLIENT_SECRET_1")
970+
.setRefreshToken("REFRESH_TOKEN_1")
971+
.build();
972+
CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(userCredentials1);
973+
UserCredentials userCredentials2 =
974+
UserCredentials.newBuilder()
975+
.setClientId("CLIENT_ID_2")
976+
.setClientSecret("CLIENT_SECRET_2")
977+
.setRefreshToken("REFRESH_TOKEN_2")
978+
.build();
979+
CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(userCredentials2);
980+
StreamWriter writer1 =
981+
StreamWriter.newBuilder(TEST_STREAM_1)
982+
.setWriterSchema(schema1)
983+
.setLocation("US")
984+
.setEnableConnectionPool(true)
985+
.setMaxInflightRequests(1)
986+
.setCredentialsProvider(credentialsProvider1)
987+
.build();
988+
StreamWriter writer2 =
989+
StreamWriter.newBuilder(TEST_STREAM_2)
990+
.setWriterSchema(schema2)
991+
.setMaxInflightRequests(1)
992+
.setEnableConnectionPool(true)
993+
.setLocation("US")
994+
.setCredentialsProvider(credentialsProvider2)
995+
.build();
996+
assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 2);
997+
}
998+
999+
@Test
1000+
public void testFixedCredentialProvider_twoProviderSameCredentialSharePool() throws Exception {
1001+
// Use the shared connection mode.
1002+
ConnectionWorkerPool.setOptions(
1003+
Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build());
1004+
ProtoSchema schema1 = createProtoSchema("Schema1");
1005+
ProtoSchema schema2 = createProtoSchema("Schema2");
1006+
UserCredentials userCredentials =
1007+
UserCredentials.newBuilder()
1008+
.setClientId("CLIENT_ID_1")
1009+
.setClientSecret("CLIENT_SECRET_1")
1010+
.setRefreshToken("REFRESH_TOKEN_1")
1011+
.build();
1012+
CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(userCredentials);
1013+
CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(userCredentials);
1014+
StreamWriter writer1 =
1015+
StreamWriter.newBuilder(TEST_STREAM_1)
1016+
.setWriterSchema(schema1)
1017+
.setLocation("US")
1018+
.setEnableConnectionPool(true)
1019+
.setMaxInflightRequests(1)
1020+
.setCredentialsProvider(credentialsProvider1)
1021+
.build();
1022+
StreamWriter writer2 =
1023+
StreamWriter.newBuilder(TEST_STREAM_2)
1024+
.setWriterSchema(schema2)
1025+
.setMaxInflightRequests(1)
1026+
.setEnableConnectionPool(true)
1027+
.setLocation("US")
1028+
.setCredentialsProvider(credentialsProvider2)
1029+
.build();
1030+
assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 1);
1031+
}
1032+
9271033
@Test
9281034
public void testDefaultValueInterpretation_multiplexingCase() throws Exception {
9291035
// Use the shared connection mode.

0 commit comments

Comments
 (0)