File tree

4 files changed

+155
-3
lines changed

4 files changed

+155
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@
8080
<groupId>io.grpc</groupId>
8181
<artifactId>grpc-api</artifactId>
8282
</dependency>
83+
<dependency>
84+
<groupId>io.grpc</groupId>
85+
<artifactId>grpc-auth</artifactId>
86+
</dependency>
8387
<dependency>
8488
<groupId>io.grpc</groupId>
8589
<artifactId>grpc-context</artifactId>
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.google.common.base.Preconditions;
3838
import com.google.common.collect.ImmutableMap;
3939
import com.google.common.collect.ImmutableSet;
40+
import io.grpc.CallCredentials;
4041
import io.grpc.ManagedChannelBuilder;
4142
import java.io.IOException;
4243
import java.net.MalformedURLException;
@@ -72,6 +73,16 @@ public class SpannerOptions extends ServiceOptions<Spanner, SpannerOptions> {
7273
private final InstanceAdminStubSettings instanceAdminStubSettings;
7374
private final DatabaseAdminStubSettings databaseAdminStubSettings;
7475
private final Duration partitionedDmlTimeout;
76+
private final CallCredentialsProvider callCredentialsProvider;
77+
78+
/**
79+
* Interface that can be used to provide {@link CallCredentials} instead of {@link Credentials} to
80+
* {@link SpannerOptions}.
81+
*/
82+
public static interface CallCredentialsProvider {
83+
/** Return the {@link CallCredentials} to use for a gRPC call. */
84+
CallCredentials getCallCredentials();
85+
}
7586

7687
/** Default implementation of {@code SpannerFactory}. */
7788
private static class DefaultSpannerFactory implements SpannerFactory {
@@ -119,6 +130,7 @@ private SpannerOptions(Builder builder) {
119130
throw SpannerExceptionFactory.newSpannerException(e);
120131
}
121132
partitionedDmlTimeout = builder.partitionedDmlTimeout;
133+
callCredentialsProvider = builder.callCredentialsProvider;
122134
}
123135

124136
/** Builder for {@link SpannerOptions} instances. */
@@ -150,6 +162,7 @@ public static class Builder
150162
private DatabaseAdminStubSettings.Builder databaseAdminStubSettingsBuilder =
151163
DatabaseAdminStubSettings.newBuilder();
152164
private Duration partitionedDmlTimeout = Duration.ofHours(2L);
165+
private CallCredentialsProvider callCredentialsProvider;
153166
private String emulatorHost = System.getenv("SPANNER_EMULATOR_HOST");
154167

155168
private Builder() {}
@@ -164,6 +177,7 @@ private Builder() {}
164177
this.instanceAdminStubSettingsBuilder = options.instanceAdminStubSettings.toBuilder();
165178
this.databaseAdminStubSettingsBuilder = options.databaseAdminStubSettings.toBuilder();
166179
this.partitionedDmlTimeout = options.partitionedDmlTimeout;
180+
this.callCredentialsProvider = options.callCredentialsProvider;
167181
this.channelProvider = options.channelProvider;
168182
this.channelConfigurator = options.channelConfigurator;
169183
this.interceptorProvider = options.interceptorProvider;
@@ -355,6 +369,17 @@ public Builder setPartitionedDmlTimeout(Duration timeout) {
355369
return this;
356370
}
357371

372+
/**
373+
* Sets a {@link CallCredentialsProvider} that can deliver {@link CallCredentials} to use on a
374+
* per-gRPC basis. Any credentials returned by this {@link CallCredentialsProvider} will have
375+
* preference above any {@link Credentials} that may have been set on the {@link SpannerOptions}
376+
* instance.
377+
*/
378+
public Builder setCallCredentialsProvider(CallCredentialsProvider callCredentialsProvider) {
379+
this.callCredentialsProvider = callCredentialsProvider;
380+
return this;
381+
}
382+
358383
/**
359384
* Specifying this will allow the client to prefetch up to {@code prefetchChunks} {@code
360385
* PartialResultSet} chunks for each read and query. The data size of each chunk depends on the
@@ -452,6 +477,10 @@ public Duration getPartitionedDmlTimeout() {
452477
return partitionedDmlTimeout;
453478
}
454479

480+
public CallCredentialsProvider getCallCredentialsProvider() {
481+
return callCredentialsProvider;
482+
}
483+
455484
public int getPrefetchChunks() {
456485
return prefetchChunks;
457486
}
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@
4040
import com.google.cloud.spanner.SpannerException;
4141
import com.google.cloud.spanner.SpannerExceptionFactory;
4242
import com.google.cloud.spanner.SpannerOptions;
43+
import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider;
4344
import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStub;
4445
import com.google.cloud.spanner.admin.database.v1.stub.GrpcDatabaseAdminStub;
4546
import com.google.cloud.spanner.admin.instance.v1.stub.GrpcInstanceAdminStub;
4647
import com.google.cloud.spanner.admin.instance.v1.stub.InstanceAdminStub;
4748
import com.google.cloud.spanner.v1.stub.GrpcSpannerStub;
4849
import com.google.cloud.spanner.v1.stub.SpannerStub;
50+
import com.google.common.annotations.VisibleForTesting;
4951
import com.google.common.base.MoreObjects;
5052
import com.google.common.base.Preconditions;
5153
import com.google.common.util.concurrent.ThreadFactoryBuilder;
@@ -99,6 +101,7 @@
99101
import com.google.spanner.v1.RollbackRequest;
100102
import com.google.spanner.v1.Session;
101103
import com.google.spanner.v1.Transaction;
104+
import io.grpc.CallCredentials;
102105
import io.grpc.Context;
103106
import java.io.UnsupportedEncodingException;
104107
import java.net.URLDecoder;
@@ -174,6 +177,7 @@ private synchronized void shutdown() {
174177
private final String projectId;
175178
private final String projectName;
176179
private final SpannerMetadataProvider metadataProvider;
180+
private final CallCredentialsProvider callCredentialsProvider;
177181
private final Duration waitTimeout =
178182
systemProperty(PROPERTY_TIMEOUT_SECONDS, DEFAULT_TIMEOUT_SECONDS);
179183
private final Duration idleTimeout =
@@ -216,6 +220,7 @@ public GapicSpannerRpc(final SpannerOptions options) {
216220
SpannerMetadataProvider.create(
217221
mergedHeaderProvider.getHeaders(),
218222
internalHeaderProviderBuilder.getResourceHeaderKey());
223+
this.callCredentialsProvider = options.getCallCredentialsProvider();
219224

220225
// Create a managed executor provider.
221226
this.executorProvider =
@@ -702,7 +707,8 @@ private static <T> T get(final Future<T> future) throws SpannerException {
702707
}
703708
}
704709

705-
private GrpcCallContext newCallContext(@Nullable Map<Option, ?> options, String resource) {
710+
@VisibleForTesting
711+
GrpcCallContext newCallContext(@Nullable Map<Option, ?> options, String resource) {
706712
return newCallContext(options, resource, null);
707713
}
708714

@@ -716,6 +722,13 @@ private GrpcCallContext newCallContext(
716722
if (timeout != null) {
717723
context = context.withTimeout(timeout);
718724
}
725+
if (callCredentialsProvider != null) {
726+
CallCredentials callCredentials = callCredentialsProvider.getCallCredentials();
727+
if (callCredentials != null) {
728+
context =
729+
context.withCallOptions(context.getCallOptions().withCallCredentials(callCredentials));
730+
}
731+
}
719732
return context.withStreamWaitTimeout(waitTimeout).withStreamIdleTimeout(idleTimeout);
720733
}
721734

Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
package com.google.cloud.spanner.spi.v1;
1818

19+
import static com.google.common.truth.Truth.assertThat;
1920
import static org.hamcrest.CoreMatchers.equalTo;
2021
import static org.hamcrest.CoreMatchers.is;
2122
import static org.hamcrest.MatcherAssert.assertThat;
2223

2324
import com.google.api.core.ApiFunction;
24-
import com.google.cloud.NoCredentials;
25+
import com.google.auth.oauth2.AccessToken;
26+
import com.google.auth.oauth2.OAuth2Credentials;
2527
import com.google.cloud.spanner.DatabaseAdminClient;
2628
import com.google.cloud.spanner.DatabaseClient;
2729
import com.google.cloud.spanner.DatabaseId;
@@ -31,9 +33,11 @@
3133
import com.google.cloud.spanner.ResultSet;
3234
import com.google.cloud.spanner.Spanner;
3335
import com.google.cloud.spanner.SpannerOptions;
36+
import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider;
3437
import com.google.cloud.spanner.Statement;
3538
import com.google.cloud.spanner.admin.database.v1.MockDatabaseAdminImpl;
3639
import com.google.cloud.spanner.admin.instance.v1.MockInstanceAdminImpl;
40+
import com.google.cloud.spanner.spi.v1.SpannerRpc.Option;
3741
import com.google.common.base.Stopwatch;
3842
import com.google.protobuf.ListValue;
3943
import com.google.spanner.admin.database.v1.Database;
@@ -45,13 +49,24 @@
4549
import com.google.spanner.v1.StructType;
4650
import com.google.spanner.v1.StructType.Field;
4751
import com.google.spanner.v1.TypeCode;
52+
import io.grpc.CallCredentials;
53+
import io.grpc.Context;
54+
import io.grpc.Contexts;
4855
import io.grpc.ManagedChannelBuilder;
56+
import io.grpc.Metadata;
57+
import io.grpc.Metadata.Key;
4958
import io.grpc.Server;
59+
import io.grpc.ServerCall;
60+
import io.grpc.ServerCallHandler;
61+
import io.grpc.ServerInterceptor;
62+
import io.grpc.auth.MoreCallCredentials;
5063
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
5164
import java.io.IOException;
5265
import java.net.InetSocketAddress;
5366
import java.util.ArrayList;
67+
import java.util.HashMap;
5468
import java.util.List;
69+
import java.util.Map;
5570
import java.util.concurrent.TimeUnit;
5671
import java.util.regex.Pattern;
5772
import org.junit.After;
@@ -91,11 +106,27 @@ public class GapicSpannerRpcTest {
91106
.build())
92107
.setMetadata(SELECT1AND2_METADATA)
93108
.build();
109+
private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN";
110+
private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN";
111+
private static final OAuth2Credentials STATIC_CREDENTIALS =
112+
OAuth2Credentials.create(
113+
new AccessToken(
114+
STATIC_OAUTH_TOKEN,
115+
new java.util.Date(
116+
System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS))));
117+
private static final OAuth2Credentials VARIABLE_CREDENTIALS =
118+
OAuth2Credentials.create(
119+
new AccessToken(
120+
VARIABLE_OAUTH_TOKEN,
121+
new java.util.Date(
122+
System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(1L, TimeUnit.DAYS))));
123+
94124
private MockSpannerServiceImpl mockSpanner;
95125
private MockInstanceAdminImpl mockInstanceAdmin;
96126
private MockDatabaseAdminImpl mockDatabaseAdmin;
97127
private Server server;
98128
private InetSocketAddress address;
129+
private final Map<SpannerRpc.Option, Object> optionsMap = new HashMap<>();
99130

100131
@Before
101132
public void startServer() throws IOException {
@@ -111,8 +142,24 @@ public void startServer() throws IOException {
111142
.addService(mockSpanner)
112143
.addService(mockInstanceAdmin)
113144
.addService(mockDatabaseAdmin)
145+
// Add a server interceptor that will check that we receive the variable OAuth token
146+
// from the CallCredentials, and not the one set as static credentials.
147+
.intercept(
148+
new ServerInterceptor() {
149+
@Override
150+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
151+
ServerCall<ReqT, RespT> call,
152+
Metadata headers,
153+
ServerCallHandler<ReqT, RespT> next) {
154+
String auth =
155+
headers.get(Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER));
156+
assertThat(auth).isEqualTo("Bearer " + VARIABLE_OAUTH_TOKEN);
157+
return Contexts.interceptCall(Context.current(), call, headers, next);
158+
}
159+
})
114160
.build()
115161
.start();
162+
optionsMap.put(Option.CHANNEL_HINT, Long.valueOf(1L));
116163
}
117164

118165
@After
@@ -229,6 +276,55 @@ && getNumberOfThreadsWithName(SPANNER_THREAD_NAME, false)
229276
assertThat(getNumberOfThreadsWithName(SPANNER_THREAD_NAME, true), is(equalTo(0)));
230277
}
231278

279+
@Test
280+
public void testCallCredentialsProviderPreferenceAboveCredentials() {
281+
SpannerOptions options =
282+
SpannerOptions.newBuilder()
283+
.setCredentials(STATIC_CREDENTIALS)
284+
.setCallCredentialsProvider(
285+
new CallCredentialsProvider() {
286+
@Override
287+
public CallCredentials getCallCredentials() {
288+
return MoreCallCredentials.from(VARIABLE_CREDENTIALS);
289+
}
290+
})
291+
.build();
292+
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
293+
// GoogleAuthLibraryCallCredentials doesn't implement equals, so we can only check for the
294+
// existence.
295+
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
296+
.isNotNull();
297+
rpc.shutdown();
298+
}
299+
300+
@Test
301+
public void testCallCredentialsProviderReturnsNull() {
302+
SpannerOptions options =
303+
SpannerOptions.newBuilder()
304+
.setCredentials(STATIC_CREDENTIALS)
305+
.setCallCredentialsProvider(
306+
new CallCredentialsProvider() {
307+
@Override
308+
public CallCredentials getCallCredentials() {
309+
return null;
310+
}
311+
})
312+
.build();
313+
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
314+
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
315+
.isNull();
316+
rpc.shutdown();
317+
}
318+
319+
@Test
320+
public void testNoCallCredentials() {
321+
SpannerOptions options = SpannerOptions.newBuilder().setCredentials(STATIC_CREDENTIALS).build();
322+
GapicSpannerRpc rpc = new GapicSpannerRpc(options);
323+
assertThat(rpc.newCallContext(optionsMap, "/some/resource").getCallOptions().getCredentials())
324+
.isNull();
325+
rpc.shutdown();
326+
}
327+
232328
@SuppressWarnings("rawtypes")
233329
private SpannerOptions createSpannerOptions() {
234330
String endpoint = address.getHostString() + ":" + server.getPort();
@@ -244,7 +340,17 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder input) {
244340
}
245341
})
246342
.setHost("http://" + endpoint)
247-
.setCredentials(NoCredentials.getInstance())
343+
// Set static credentials that will return the static OAuth test token.
344+
.setCredentials(STATIC_CREDENTIALS)
345+
// Also set a CallCredentialsProvider. These credentials should take precedence above
346+
// the static credentials.
347+
.setCallCredentialsProvider(
348+
new CallCredentialsProvider() {
349+
@Override
350+
public CallCredentials getCallCredentials() {
351+
return MoreCallCredentials.from(VARIABLE_CREDENTIALS);
352+
}
353+
})
248354
.build();
249355
}
250356

0 commit comments

Comments
 (0)