diff --git a/google-auth-library-java/cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java b/google-auth-library-java/cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java index 46ab2fd2d24c..cb1e1caa89b8 100644 --- a/google-auth-library-java/cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java +++ b/google-auth-library-java/cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java @@ -53,6 +53,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Strings; import com.google.common.util.concurrent.AbstractFuture; +import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFutureTask; @@ -78,6 +79,7 @@ import java.util.Date; import java.util.List; import java.util.concurrent.ExecutionException; +import javax.annotation.Nullable; /** * A factory for generating downscoped access tokens using a client-side approach. @@ -246,7 +248,7 @@ void refreshCredentialsIfRequired() throws IOException { } try { // Wait for the refresh task to complete. - currentRefreshTask.get(); + currentRefreshTask.task.get(); } catch (InterruptedException e) { // Restore the interrupted status and throw an exception. Thread.currentThread().interrupt(); @@ -493,18 +495,31 @@ class RefreshTask extends AbstractFuture implements Run this.task = task; this.isNew = isNew; - // Single listener to guarantee that finishRefreshTask updates the internal state BEFORE - // the outer future completes and unblocks waiters. + // Add listener to update factory's credentials when the task completes. task.addListener( () -> { try { finishRefreshTask(task); - RefreshTask.this.set(Futures.getDone(task)); } catch (ExecutionException e) { Throwable cause = e.getCause(); - RefreshTask.this.setException(cause != null ? cause : e); - } catch (Throwable t) { - RefreshTask.this.setException(t); + RefreshTask.this.setException(cause); + } + }, + MoreExecutors.directExecutor()); + + // Add callback to set the result or exception based on the outcome. + Futures.addCallback( + task, + new FutureCallback() { + @Override + public void onSuccess(IntermediateCredentials result) { + RefreshTask.this.set(result); + } + + @Override + public void onFailure(@Nullable Throwable t) { + RefreshTask.this.setException( + t != null ? t : new IOException("Refresh failed with null Throwable.")); } }, MoreExecutors.directExecutor()); diff --git a/google-auth-library-java/cab-token-generator/javatests/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactoryTest.java b/google-auth-library-java/cab-token-generator/javatests/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactoryTest.java index bfe9077c990f..a1714a9ba92f 100644 --- a/google-auth-library-java/cab-token-generator/javatests/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactoryTest.java +++ b/google-auth-library-java/cab-token-generator/javatests/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactoryTest.java @@ -988,59 +988,4 @@ void generateToken_withMalformSessionKey_failure() throws Exception { assertThrows(GeneralSecurityException.class, () -> factory.generateToken(accessBoundary)); } - - @Test - void generateToken_freshInstance_concurrent_noNpe() throws Exception { - for (int run = 0; run < 10; run++) { // Run 10 times in a single test instance to save time - GoogleCredentials sourceCredentials = - getServiceAccountSourceCredentials(mockTokenServerTransportFactory); - ClientSideCredentialAccessBoundaryFactory factory = - ClientSideCredentialAccessBoundaryFactory.newBuilder() - .setSourceCredential(sourceCredentials) - .setHttpTransportFactory(mockStsTransportFactory) - .build(); - - CredentialAccessBoundary.Builder cabBuilder = CredentialAccessBoundary.newBuilder(); - CredentialAccessBoundary accessBoundary = - cabBuilder - .addRule( - CredentialAccessBoundary.AccessBoundaryRule.newBuilder() - .setAvailableResource("resource") - .setAvailablePermissions(ImmutableList.of("role")) - .build()) - .build(); - - int numThreads = 5; - CountDownLatch latch = new CountDownLatch(numThreads); - java.util.concurrent.atomic.AtomicInteger npeCount = - new java.util.concurrent.atomic.AtomicInteger(); - java.util.concurrent.ExecutorService executor = - java.util.concurrent.Executors.newFixedThreadPool(numThreads); - - try { - for (int i = 0; i < numThreads; i++) { - executor.submit( - () -> { - try { - latch.countDown(); - latch.await(); - factory.generateToken(accessBoundary); - } catch (NullPointerException e) { - npeCount.incrementAndGet(); - } catch (Exception e) { - // Ignore other exceptions for the sake of the race reproduction - } - }); - } - } finally { - executor.shutdown(); - executor.awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS); - } - - org.junit.jupiter.api.Assertions.assertEquals( - 0, - npeCount.get(), - "Expected zero NullPointerExceptions due to the race condition, but some were thrown."); - } - } }