1616
1717package com .google .cloud .spanner .spi .v1 ;
1818
19+ import static com .google .common .truth .Truth .assertThat ;
1920import static org .hamcrest .CoreMatchers .equalTo ;
2021import static org .hamcrest .CoreMatchers .is ;
2122import static org .hamcrest .MatcherAssert .assertThat ;
2223
2324import com .google .api .core .ApiFunction ;
24- import com .google .cloud .NoCredentials ;
25+ import com .google .auth .oauth2 .AccessToken ;
26+ import com .google .auth .oauth2 .OAuth2Credentials ;
2527import com .google .cloud .spanner .DatabaseAdminClient ;
2628import com .google .cloud .spanner .DatabaseClient ;
2729import com .google .cloud .spanner .DatabaseId ;
3133import com .google .cloud .spanner .ResultSet ;
3234import com .google .cloud .spanner .Spanner ;
3335import com .google .cloud .spanner .SpannerOptions ;
36+ import com .google .cloud .spanner .SpannerOptions .CallCredentialsProvider ;
3437import com .google .cloud .spanner .Statement ;
3538import com .google .cloud .spanner .admin .database .v1 .MockDatabaseAdminImpl ;
3639import com .google .cloud .spanner .admin .instance .v1 .MockInstanceAdminImpl ;
40+ import com .google .cloud .spanner .spi .v1 .SpannerRpc .Option ;
3741import com .google .common .base .Stopwatch ;
3842import com .google .protobuf .ListValue ;
3943import com .google .spanner .admin .database .v1 .Database ;
4549import com .google .spanner .v1 .StructType ;
4650import com .google .spanner .v1 .StructType .Field ;
4751import com .google .spanner .v1 .TypeCode ;
52+ import io .grpc .CallCredentials ;
53+ import io .grpc .Context ;
54+ import io .grpc .Contexts ;
4855import io .grpc .ManagedChannelBuilder ;
56+ import io .grpc .Metadata ;
57+ import io .grpc .Metadata .Key ;
4958import io .grpc .Server ;
59+ import io .grpc .ServerCall ;
60+ import io .grpc .ServerCallHandler ;
61+ import io .grpc .ServerInterceptor ;
62+ import io .grpc .auth .MoreCallCredentials ;
5063import io .grpc .netty .shaded .io .grpc .netty .NettyServerBuilder ;
5164import java .io .IOException ;
5265import java .net .InetSocketAddress ;
5366import java .util .ArrayList ;
67+ import java .util .HashMap ;
5468import java .util .List ;
69+ import java .util .Map ;
5570import java .util .concurrent .TimeUnit ;
5671import java .util .regex .Pattern ;
5772import org .junit .After ;
@@ -91,11 +106,27 @@ public class GapicSpannerRpcTest {
91106 .build ())
92107 .setMetadata (SELECT1AND2_METADATA )
93108 .build ();
109+ private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN" ;
110+ private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN" ;
111+ private static final OAuth2Credentials STATIC_CREDENTIALS =
112+ OAuth2Credentials .create (
113+ new AccessToken (
114+ STATIC_OAUTH_TOKEN ,
115+ new java .util .Date (
116+ System .currentTimeMillis () + TimeUnit .MILLISECONDS .convert (1L , TimeUnit .DAYS ))));
117+ private static final OAuth2Credentials VARIABLE_CREDENTIALS =
118+ OAuth2Credentials .create (
119+ new AccessToken (
120+ VARIABLE_OAUTH_TOKEN ,
121+ new java .util .Date (
122+ System .currentTimeMillis () + TimeUnit .MILLISECONDS .convert (1L , TimeUnit .DAYS ))));
123+
94124 private MockSpannerServiceImpl mockSpanner ;
95125 private MockInstanceAdminImpl mockInstanceAdmin ;
96126 private MockDatabaseAdminImpl mockDatabaseAdmin ;
97127 private Server server ;
98128 private InetSocketAddress address ;
129+ private final Map <SpannerRpc .Option , Object > optionsMap = new HashMap <>();
99130
100131 @ Before
101132 public void startServer () throws IOException {
@@ -111,8 +142,24 @@ public void startServer() throws IOException {
111142 .addService (mockSpanner )
112143 .addService (mockInstanceAdmin )
113144 .addService (mockDatabaseAdmin )
145+ // Add a server interceptor that will check that we receive the variable OAuth token
146+ // from the CallCredentials, and not the one set as static credentials.
147+ .intercept (
148+ new ServerInterceptor () {
149+ @ Override
150+ public <ReqT , RespT > ServerCall .Listener <ReqT > interceptCall (
151+ ServerCall <ReqT , RespT > call ,
152+ Metadata headers ,
153+ ServerCallHandler <ReqT , RespT > next ) {
154+ String auth =
155+ headers .get (Key .of ("authorization" , Metadata .ASCII_STRING_MARSHALLER ));
156+ assertThat (auth ).isEqualTo ("Bearer " + VARIABLE_OAUTH_TOKEN );
157+ return Contexts .interceptCall (Context .current (), call , headers , next );
158+ }
159+ })
114160 .build ()
115161 .start ();
162+ optionsMap .put (Option .CHANNEL_HINT , Long .valueOf (1L ));
116163 }
117164
118165 @ After
@@ -229,6 +276,55 @@ && getNumberOfThreadsWithName(SPANNER_THREAD_NAME, false)
229276 assertThat (getNumberOfThreadsWithName (SPANNER_THREAD_NAME , true ), is (equalTo (0 )));
230277 }
231278
279+ @ Test
280+ public void testCallCredentialsProviderPreferenceAboveCredentials () {
281+ SpannerOptions options =
282+ SpannerOptions .newBuilder ()
283+ .setCredentials (STATIC_CREDENTIALS )
284+ .setCallCredentialsProvider (
285+ new CallCredentialsProvider () {
286+ @ Override
287+ public CallCredentials getCallCredentials () {
288+ return MoreCallCredentials .from (VARIABLE_CREDENTIALS );
289+ }
290+ })
291+ .build ();
292+ GapicSpannerRpc rpc = new GapicSpannerRpc (options );
293+ // GoogleAuthLibraryCallCredentials doesn't implement equals, so we can only check for the
294+ // existence.
295+ assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
296+ .isNotNull ();
297+ rpc .shutdown ();
298+ }
299+
300+ @ Test
301+ public void testCallCredentialsProviderReturnsNull () {
302+ SpannerOptions options =
303+ SpannerOptions .newBuilder ()
304+ .setCredentials (STATIC_CREDENTIALS )
305+ .setCallCredentialsProvider (
306+ new CallCredentialsProvider () {
307+ @ Override
308+ public CallCredentials getCallCredentials () {
309+ return null ;
310+ }
311+ })
312+ .build ();
313+ GapicSpannerRpc rpc = new GapicSpannerRpc (options );
314+ assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
315+ .isNull ();
316+ rpc .shutdown ();
317+ }
318+
319+ @ Test
320+ public void testNoCallCredentials () {
321+ SpannerOptions options = SpannerOptions .newBuilder ().setCredentials (STATIC_CREDENTIALS ).build ();
322+ GapicSpannerRpc rpc = new GapicSpannerRpc (options );
323+ assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
324+ .isNull ();
325+ rpc .shutdown ();
326+ }
327+
232328 @ SuppressWarnings ("rawtypes" )
233329 private SpannerOptions createSpannerOptions () {
234330 String endpoint = address .getHostString () + ":" + server .getPort ();
@@ -244,7 +340,17 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder input) {
244340 }
245341 })
246342 .setHost ("http://" + endpoint )
247- .setCredentials (NoCredentials .getInstance ())
343+ // Set static credentials that will return the static OAuth test token.
344+ .setCredentials (STATIC_CREDENTIALS )
345+ // Also set a CallCredentialsProvider. These credentials should take precedence above
346+ // the static credentials.
347+ .setCallCredentialsProvider (
348+ new CallCredentialsProvider () {
349+ @ Override
350+ public CallCredentials getCallCredentials () {
351+ return MoreCallCredentials .from (VARIABLE_CREDENTIALS );
352+ }
353+ })
248354 .build ();
249355 }
250356
0 commit comments