Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER"));

private final Mono<Authentication> currentAuthenticationMono = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
.mapNotNull(SecurityContext::getAuthentication);

// @formatter:off
private final Mono<String> clientRegistrationIdMono = this.currentAuthenticationMono
Expand All @@ -144,6 +143,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements

private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();

private PrincipalResolver principalResolver = (request) -> this.currentAuthenticationMono;

/**
* Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the
* provided parameters.
Expand Down Expand Up @@ -326,6 +327,15 @@ public void setDefaultClientRegistrationId(String clientRegistrationId) {

@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
return this.principalResolver.resolve(request)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN)
.flatMap((authentication) -> doFilter(request, next)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)));
// @formatter:on
}

private Mono<ClientResponse> doFilter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
return authorizedClient(request)
.map((authorizedClient) -> bearer(request, authorizedClient))
Expand Down Expand Up @@ -477,13 +487,46 @@ public void setServerSecurityContextRepository(ServerSecurityContextRepository s
this.serverSecurityContextRepository = serverSecurityContextRepository;
}

/**
* Sets the strategy for resolving a {@link Mono} of the {@link Authentication
* principal} from an intercepted request.
* @param principalResolver the strategy for resolving a {@link Mono} of the
* {@link Authentication principal}
* @since 7.1
*/
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
}

@FunctionalInterface
private interface ClientResponseHandler {

Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> response);

}

/**
* A strategy for resolving a {@link Mono} of the {@link Authentication principal}
* from an intercepted request.
*
* @since 7.1
*/
@FunctionalInterface
public interface PrincipalResolver {

/**
* Resolve a {@link Mono} of the {@link Authentication principal} from the current
* request, which is used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Mono} of the {@link Authentication principal} to be used for
* resolving an {@link OAuth2AuthorizedClient}
*/
Mono<Authentication> resolve(ClientRequest request);

}

/**
* Forwards authentication and authorization failures to a
* {@link ReactiveOAuth2AuthorizationFailureHandler}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;
Expand Down Expand Up @@ -122,6 +123,7 @@
* @author Rob Winch
* @author Joe Grandja
* @author Roman Matiushchenko
* @author Evgeniy Cheban
* @since 5.1
* @see OAuth2AuthorizedClientManager
* @see DefaultOAuth2AuthorizedClientManager
Expand Down Expand Up @@ -151,6 +153,13 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();

/*
* For consistency, the default implementation resolves a principal from request
* attributes. Request attributes are populated from Reactor context which is enriched
* in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber
*/
private PrincipalResolver principalResolver = (request) -> getAuthentication(request.attributes());

private OAuth2AuthorizedClientManager authorizedClientManager;

private boolean defaultOAuth2AuthorizedClient;
Expand Down Expand Up @@ -372,6 +381,18 @@ public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler aut
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
}

/**
* Sets the strategy for resolving a {@link Authentication principal} from an
* intercepted request.
* @param principalResolver the strategy for resolving a {@link Authentication
* principal}
* @since 7.1
*/
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
}

@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
Expand Down Expand Up @@ -459,7 +480,7 @@ private String resolveClientRegistrationId(ClientRequest request) {
if (clientRegistrationId == null) {
clientRegistrationId = this.defaultClientRegistrationId;
}
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient
&& authentication instanceof OAuth2AuthenticationToken) {
clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
Expand All @@ -472,7 +493,7 @@ private Mono<OAuth2AuthorizedClient> authorizeClient(String clientRegistrationId
return Mono.empty();
}
Map<String, Object> attrs = request.attributes();
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (authentication == null) {
authentication = ANONYMOUS_AUTHENTICATION;
}
Expand All @@ -495,7 +516,7 @@ private Mono<OAuth2AuthorizedClient> reauthorizeClient(OAuth2AuthorizedClient au
return Mono.just(authorizedClient);
}
Map<String, Object> attrs = request.attributes();
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (authentication == null) {
authentication = createAuthentication(authorizedClient.getPrincipalName());
}
Expand Down Expand Up @@ -567,6 +588,27 @@ public Object getPrincipal() {
};
}

/**
* A strategy for resolving a {@link Authentication principal} from an intercepted
* request.
*
* @since 7.1
*/
@FunctionalInterface
public interface PrincipalResolver {

/**
* Resolve a {@link Authentication principal} from the current request, which is
* used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Mono} of the {@link Authentication principal} to be used for
* resolving an {@link OAuth2AuthorizedClient}
*/
@Nullable Authentication resolve(ClientRequest request);

}

@FunctionalInterface
private interface ClientResponseHandler {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ public void setServerSecurityContextRepositoryWhenHandlerIsNullThenThrowIllegalA
.setServerSecurityContextRepository(null));
}

@Test
public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setPrincipalResolver(null));
}

@Test
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
Expand Down Expand Up @@ -791,6 +798,38 @@ public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClient
assertThat(getBody(request0)).isEmpty();
}

@Test
public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
this.function.setDefaultOAuth2AuthorizedClient(true);
OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"),
Collections.singletonMap("user", "rob"), "user");
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
"initial-registration-id");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
this.registration.getRegistrationId());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken);
given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, this.serverWebExchange))
.willReturn(Mono.just(authorizedClient));
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.build();
this.function.setPrincipalResolver((request) -> Mono.just(authentication));
this.function.filter(clientRequest, this.exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(initialAuthentication))
.contextWrite(serverWebExchange())
.block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request0 = requests.get(0);
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request0.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request0)).isEmpty();
verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, this.serverWebExchange);
}

@Test
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@

/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -217,6 +218,13 @@ public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgument
.isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null));
}

@Test
public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setPrincipalResolver(null));
}

@Test
public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() {
Map<String, Object> attrs = getDefaultRequestAttributes();
Expand Down Expand Up @@ -620,6 +628,39 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
assertThat(getBody(request)).isEmpty();
}

@Test
public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, authorities,
"initial-registration-id");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, authorities,
this.registration.getRegistrationId());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken);
given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(),
initialAuthentication, servletRequest))
.willReturn(authorizedClient);
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.build();
this.function.setPrincipalResolver((request) -> authentication);
this.function.filter(clientRequest, this.exchange)
.contextWrite(context(servletRequest, servletResponse, initialAuthentication))
.block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, servletRequest);
}

@Test
public void filterWhenUnauthorizedThenInvokeFailureHandler() {
assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);
Expand Down
Loading