Skip to content

Commit e7c7bdd

Browse files
authored
feat: improve reliability of refresh operations (#1147)
Update refresh calculation algorithm to support 24-hour ephemeral certs
1 parent 2c5d58c commit e7c7bdd

File tree

9 files changed

+96
-43
lines changed

9 files changed

+96
-43
lines changed

core/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@
124124
</dependencyManagement>
125125

126126
<dependencies>
127+
<dependency>
128+
<groupId>dev.failsafe</groupId>
129+
<artifactId>failsafe</artifactId>
130+
<version>3.3.0</version>
131+
</dependency>
132+
127133
<dependency>
128134
<groupId>junit</groupId>
129135
<artifactId>junit</artifactId>

core/src/main/java/com/google/cloud/sql/core/CloudSqlInstance.java

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838
import com.google.common.util.concurrent.Futures;
3939
import com.google.common.util.concurrent.ListenableFuture;
4040
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
41-
import com.google.common.util.concurrent.RateLimiter;
4241
import com.google.common.util.concurrent.SettableFuture;
4342
import com.google.common.util.concurrent.Uninterruptibles;
4443
import com.google.errorprone.annotations.concurrent.GuardedBy;
44+
import dev.failsafe.RateLimiter;
4545
import java.io.ByteArrayInputStream;
4646
import java.io.IOException;
4747
import java.nio.charset.StandardCharsets;
@@ -95,12 +95,7 @@ class CloudSqlInstance {
9595
// defaultRefreshBuffer is the minimum amount of time for which a
9696
// certificate must be valid to ensure the next refresh attempt has adequate
9797
// time to complete.
98-
private static final Duration DEFAULT_REFRESH_BUFFER = Duration.ofMinutes(5);
99-
// iamAuthRefreshBuffer is the minimum amount of time for which a
100-
// certificate holding an Access Token must be valid. Because some token
101-
// sources are refreshed with only ~60 seconds before expiration, this value
102-
// must be smaller than the defaultRefreshBuffer.
103-
private static final Duration IAM_AUTH_REFRESH_BUFFER = Duration.ofSeconds(55);
98+
private static final Duration DEFAULT_REFRESH_BUFFER = Duration.ofMinutes(4);
10499
private final ListeningScheduledExecutorService executor;
105100
private final SQLAdmin apiClient;
106101
private final boolean enableIamAuth;
@@ -113,7 +108,8 @@ class CloudSqlInstance {
113108
private final ListenableFuture<KeyPair> keyPair;
114109
private final Object instanceDataGuard = new Object();
115110
// Limit forced refreshes to 2 every 30 seconds.
116-
private final RateLimiter forcedRenewRateLimiter = RateLimiter.create(1.0 / 60.0);
111+
private final RateLimiter<Object> forcedRenewRateLimiter = RateLimiter.burstyBuilder(2,
112+
Duration.ofSeconds(30)).build();
117113
@GuardedBy("instanceDataGuard")
118114
private ListenableFuture<InstanceData> currentInstanceData;
119115
@GuardedBy("instanceDataGuard")
@@ -134,7 +130,7 @@ class CloudSqlInstance {
134130
boolean enableIamAuth,
135131
CredentialFactory tokenSourceFactory,
136132
ListeningScheduledExecutorService executor,
137-
ListenableFuture<KeyPair> keyPair) throws IOException {
133+
ListenableFuture<KeyPair> keyPair) throws IOException, InterruptedException {
138134

139135
Matcher matcher = CONNECTION_NAME.matcher(connectionName);
140136
checkArgument(
@@ -272,6 +268,22 @@ static GoogleCredentials getDownscopedCredentials(OAuth2Credentials credentials)
272268
return downscoped;
273269
}
274270

271+
static long secondsUntilRefresh(Date expiration) {
272+
Duration timeUntilExp = Duration.between(Instant.now(), expiration.toInstant());
273+
274+
if (timeUntilExp.compareTo(Duration.ofHours(1)) < 0) {
275+
if (timeUntilExp.compareTo(DEFAULT_REFRESH_BUFFER) < 0) {
276+
// If the time until the certificate expires is less than the refresh buffer, schedule the
277+
// refresh immediately
278+
return 0;
279+
}
280+
// Otherwise schedule a refresh in (timeUntilExp - buffer) seconds
281+
return timeUntilExp.minus(DEFAULT_REFRESH_BUFFER).getSeconds();
282+
}
283+
// If the certificate expires in an hour or more, refresh after half of timeUntilExp
284+
return timeUntilExp.dividedBy(2).getSeconds();
285+
}
286+
275287
private OAuth2Credentials parseCredentials(HttpRequestInitializer source) {
276288
if (source instanceof HttpCredentialsAdapter) {
277289
HttpCredentialsAdapter adapter = (HttpCredentialsAdapter) source;
@@ -361,7 +373,7 @@ String getPreferredIp(List<String> preferredTypes) {
361373
*
362374
* @return {@code true} if successfully scheduled, or {@code false} otherwise.
363375
*/
364-
boolean forceRefresh() {
376+
boolean forceRefresh() throws InterruptedException {
365377
synchronized (instanceDataGuard) {
366378
// If a scheduled refresh hasn't started, perform one immediately
367379
if (nextInstanceData.cancel(false)) {
@@ -380,9 +392,9 @@ boolean forceRefresh() {
380392
* value of currentInstanceData and schedules the next refresh shortly before the information
381393
* would expire.
382394
*/
383-
private ListenableFuture<InstanceData> performRefresh() {
395+
private ListenableFuture<InstanceData> performRefresh() throws InterruptedException {
384396
// To avoid unreasonable SQL Admin API usage, use a rate limit to throttle our usage.
385-
forcedRenewRateLimiter.acquire(1);
397+
forcedRenewRateLimiter.acquirePermit();
386398
// Use the Cloud SQL Admin API to return the Metadata and Certificate
387399
ListenableFuture<Metadata> metadataFuture = executor.submit(this::fetchMetadata);
388400
ListenableFuture<Certificate> ephemeralCertificateFuture =
@@ -432,7 +444,7 @@ public void onSuccess(InstanceData instanceData) {
432444
// schedule a replacement before the SSLContext expires;
433445
nextInstanceData = executor
434446
.schedule(() -> performRefresh(),
435-
secondsUntilRefresh(),
447+
secondsUntilRefresh(getInstanceData().getExpiration()),
436448
TimeUnit.SECONDS);
437449
}
438450
}
@@ -452,7 +464,11 @@ public void onFailure(Throwable t) {
452464
// replace current if it is expired or invalid
453465
currentInstanceData = refreshFuture;
454466
}
455-
nextInstanceData = Futures.immediateFuture(performRefresh());
467+
try {
468+
nextInstanceData = Futures.immediateFuture(performRefresh());
469+
} catch (InterruptedException e) {
470+
throw new RuntimeException(e);
471+
}
456472
}
457473
}
458474
}, executor);
@@ -632,23 +648,6 @@ private Optional<Date> getTokenExpirationTime(Credential credentials) {
632648
.map(expirationTime -> new Date(expirationTime));
633649
}
634650

635-
private long secondsUntilRefresh() {
636-
Duration refreshBuffer = enableIamAuth ? IAM_AUTH_REFRESH_BUFFER : DEFAULT_REFRESH_BUFFER;
637-
638-
Date expiration = getInstanceData().getExpiration();
639-
640-
Duration timeUntilRefresh = Duration.between(Instant.now(), expiration.toInstant())
641-
.minus(refreshBuffer);
642-
643-
if (timeUntilRefresh.isNegative()) {
644-
// If the time until the certificate expires is less than the buffer, schedule the refresh
645-
// closer to the expiration time
646-
timeUntilRefresh = Duration.between(Instant.now(), expiration.toInstant())
647-
.minus(Duration.ofSeconds(5));
648-
}
649-
return timeUntilRefresh.getSeconds();
650-
}
651-
652651
/**
653652
* Checks for common errors that can occur when interacting with the Cloud SQL Admin API, and adds
654653
* additional context to help the user troubleshoot them.

core/src/main/java/com/google/cloud/sql/core/CoreSocketFactory.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ private CloudSqlInstance getCloudSqlInstance(String instanceName, boolean enable
164164
try {
165165
return new CloudSqlInstance(k, adminApi, enableIamAuth, credentialFactory, executor,
166166
localKeyPair);
167-
} catch (IOException e) {
167+
} catch (IOException | InterruptedException e) {
168168
throw new RuntimeException(e);
169169
}
170170
});
@@ -211,7 +211,7 @@ private static String getUnixSocketArg(Properties props) {
211211
/**
212212
* Creates a socket representing a connection to a Cloud SQL instance.
213213
*/
214-
public static Socket connect(Properties props) throws IOException {
214+
public static Socket connect(Properties props) throws IOException, InterruptedException {
215215
return connect(props, null);
216216
}
217217

@@ -225,7 +225,8 @@ public static Socket connect(Properties props) throws IOException {
225225
* @return the newly created Socket.
226226
* @throws IOException if error occurs during socket creation.
227227
*/
228-
public static Socket connect(Properties props, String unixPathSuffix) throws IOException {
228+
public static Socket connect(Properties props, String unixPathSuffix)
229+
throws IOException, InterruptedException {
229230
// Gather parameters
230231
final String csqlInstanceName = props.getProperty(CLOUD_SQL_INSTANCE_PROPERTY);
231232
final boolean enableIamAuth = Boolean.parseBoolean(props.getProperty("enableIamAuth"));
@@ -303,7 +304,7 @@ private String getHostIp(String instanceName, List<String> ipTypes) {
303304
// TODO(berezv): separate creating socket and performing connection to make it easier to test
304305
@VisibleForTesting
305306
Socket createSslSocket(String instanceName, List<String> ipTypes, boolean enableIamAuth)
306-
throws IOException {
307+
throws IOException, InterruptedException {
307308
CloudSqlInstance instance = getCloudSqlInstance(instanceName, enableIamAuth);
308309

309310
try {
@@ -327,7 +328,8 @@ Socket createSslSocket(String instanceName, List<String> ipTypes, boolean enable
327328
}
328329
}
329330

330-
Socket createSslSocket(String instanceName, List<String> ipTypes) throws IOException {
331+
Socket createSslSocket(String instanceName, List<String> ipTypes)
332+
throws IOException, InterruptedException {
331333
return createSslSocket(instanceName, ipTypes, false);
332334
}
333335

core/src/test/java/com/google/cloud/sql/core/CloudSqlInstanceTest.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
import com.google.auth.oauth2.GoogleCredentials;
2525
import com.google.auth.oauth2.OAuth2Credentials;
2626
import java.io.IOException;
27+
import java.time.Duration;
28+
import java.time.Instant;
29+
import java.util.Date;
30+
import org.junit.Assert;
2731
import org.junit.Before;
2832
import org.junit.Test;
2933
import org.junit.runner.RunWith;
@@ -88,4 +92,23 @@ public void throwsErrorIamAuthNotSupported() {
8892
}
8993
}
9094

95+
@Test
96+
public void timeUntilRefreshImmediate() {
97+
Date expiration = Date.from(Instant.now().plus(Duration.ofMinutes(3)));
98+
assertThat(CloudSqlInstance.secondsUntilRefresh(expiration)).isEqualTo(0L);
99+
}
100+
101+
@Test
102+
public void timeUntilRefresh1Hr() {
103+
Date expiration = Date.from(Instant.now().plus(Duration.ofMinutes(59)));
104+
Long expected = Duration.ofMinutes(59).minus(Duration.ofMinutes(4)).getSeconds();
105+
Assert.assertEquals(CloudSqlInstance.secondsUntilRefresh(expiration), expected, 1);
106+
}
107+
108+
@Test
109+
public void timeUntilRefresh24Hr() {
110+
Date expiration = Date.from(Instant.now().plus(Duration.ofHours(23)));
111+
Long expected = Duration.ofHours(23).dividedBy(2).getSeconds();
112+
Assert.assertEquals(CloudSqlInstance.secondsUntilRefresh(expiration), expected, 1);
113+
}
91114
}

core/src/test/java/com/google/cloud/sql/core/CoreSocketFactoryTest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ public void create_throwsErrorForInvalidInstanceName() throws IOException {
201201
try {
202202
coreSocketFactory.createSslSocket("myProject", Arrays.asList("PRIMARY"));
203203
fail();
204-
} catch (IllegalArgumentException e) {
204+
} catch (IllegalArgumentException | InterruptedException e) {
205205
assertThat(e).hasMessageThat().contains("Cloud SQL connection name is invalid");
206206
}
207207

@@ -210,6 +210,8 @@ public void create_throwsErrorForInvalidInstanceName() throws IOException {
210210
fail();
211211
} catch (IllegalArgumentException e) {
212212
assertThat(e).hasMessageThat().contains("Cloud SQL connection name is invalid");
213+
} catch (InterruptedException e) {
214+
throw new RuntimeException(e);
213215
}
214216
}
215217

@@ -225,6 +227,8 @@ public void create_throwsErrorForInvalidInstanceRegion() throws IOException {
225227
assertThat(e)
226228
.hasMessageThat()
227229
.contains("The region specified for the Cloud SQL instance is incorrect");
230+
} catch (InterruptedException e) {
231+
throw new RuntimeException(e);
228232
}
229233
}
230234

@@ -350,7 +354,7 @@ public void create_adminApiNotEnabled() throws IOException {
350354
coreSocketFactory.createSslSocket(
351355
"NotMyProject:myRegion:myInstance", Arrays.asList("PRIMARY"));
352356
fail("Expected RuntimeException");
353-
} catch (RuntimeException e) {
357+
} catch (RuntimeException | InterruptedException e) {
354358
// TODO(berezv): should we throw something more specific than RuntimeException?
355359
assertThat(e)
356360
.hasMessageThat()
@@ -377,6 +381,8 @@ public void create_notAuthorized() throws IOException {
377381
String.format(
378382
"[%s] The Cloud SQL Instance does not exist or your account is not authorized",
379383
"myProject:myRegion:NotMyInstance"));
384+
} catch (InterruptedException e) {
385+
throw new RuntimeException(e);
380386
}
381387
}
382388

jdbc/mysql-j-5/src/main/java/com/google/cloud/sql/mysql/SocketFactory.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ public class SocketFactory implements com.mysql.jdbc.SocketFactory {
3838

3939
@Override
4040
public Socket connect(String hostname, int portNumber, Properties props) throws IOException {
41-
socket = CoreSocketFactory.connect(props);
41+
try {
42+
socket = CoreSocketFactory.connect(props);
43+
} catch (InterruptedException e) {
44+
throw new RuntimeException(e);
45+
}
4246
return socket;
4347
}
4448

jdbc/mysql-j-8/src/main/java/com/google/cloud/sql/mysql/SocketFactory.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,20 @@ public class SocketFactory implements com.mysql.cj.protocol.SocketFactory {
3939
@Override
4040
public <T extends Closeable> T connect(
4141
String host, int portNumber, PropertySet props, int loginTimeout) throws IOException {
42-
return connect(host, portNumber, props.exposeAsProperties(), loginTimeout);
42+
try {
43+
return connect(host, portNumber, props.exposeAsProperties(), loginTimeout);
44+
} catch (InterruptedException e) {
45+
throw new RuntimeException(e);
46+
}
4347
}
4448

4549
/**
4650
* Implements the interface for com.mysql.cj.protocol.SocketFactory for mysql-connector-java prior
4751
* to version 8.0.13. This change is required for backwards compatibility.
4852
*/
4953
public <T extends Closeable> T connect(
50-
String host, int portNumber, Properties props, int loginTimeout) throws IOException {
54+
String host, int portNumber, Properties props, int loginTimeout)
55+
throws IOException, InterruptedException {
5156
@SuppressWarnings("unchecked")
5257
T socket = (T) CoreSocketFactory.connect(props);
5358
return socket;

jdbc/postgres/src/main/java/com/google/cloud/sql/postgres/SocketFactory.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ private static Properties createDefaultProperties(String instanceName) {
7373

7474
@Override
7575
public Socket createSocket() throws IOException {
76-
return CoreSocketFactory.connect(props, POSTGRES_SUFFIX);
76+
try {
77+
return CoreSocketFactory.connect(props, POSTGRES_SUFFIX);
78+
} catch (InterruptedException e) {
79+
throw new RuntimeException(e);
80+
}
7781
}
7882

7983
@Override

jdbc/sqlserver/src/main/java/com/google/cloud/sql/sqlserver/SocketFactory.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ public SocketFactory(String socketFactoryConstructorArg)
7171

7272
@Override
7373
public Socket createSocket() throws IOException {
74-
return CoreSocketFactory.connect(props);
74+
try {
75+
return CoreSocketFactory.connect(props);
76+
} catch (InterruptedException e) {
77+
throw new RuntimeException(e);
78+
}
7579
}
7680

7781
@Override

0 commit comments

Comments
 (0)