Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.springframework.security.core.session;

/**
* Strategy for determining whether two principals represent the same identity.
*
* @since 7.0
*/
@FunctionalInterface
public interface PrincipalIdentifierStrategy {

/**
* Returns true if the two principals should be treated as the same logical user.
*/
boolean matches(Object existingPrincipal, Object incomingPrincipal);

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,6 @@
* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.core.session;
Expand All @@ -35,18 +25,11 @@
import org.springframework.util.Assert;

/**
* Default implementation of
* {@link org.springframework.security.core.session.SessionRegistry SessionRegistry} which
* listens for {@link org.springframework.security.core.session.SessionDestroyedEvent
* SessionDestroyedEvent}s published in the Spring application context.
* <p>
* For this class to function correctly in a web application, it is important that you
* register an <a href="
* {@docRoot}/org/springframework/security/web/session/HttpSessionEventPublisher.html">HttpSessionEventPublisher</a> in
* the <tt>web.xml</tt> file so that this class is notified of sessions that expire.
* Default implementation of {@link SessionRegistry}.
*
* Now supports pluggable Principal identity matching strategy.
*
* @author Ben Alex
* @author Luke Taylor
* @since 7.0
*/
public class SessionRegistryImpl implements SessionRegistry, ApplicationListener<AbstractSessionEvent> {

Expand All @@ -55,18 +38,36 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
// <principal:Object,SessionIdSet>
private final ConcurrentMap<Object, Set<String>> principals;

// <sessionId:Object,SessionInformation>
// <sessionId:String,SessionInformation>
private final Map<String, SessionInformation> sessionIds;

private final PrincipalIdentifierStrategy principalIdentifierStrategy;

/**
* Default constructor (backward compatible).
* Uses equals() for principal matching.
*/
public SessionRegistryImpl() {
this((existing, incoming) -> existing.equals(incoming));
}

/**
* Constructor allowing custom principal matching strategy.
*/
public SessionRegistryImpl(PrincipalIdentifierStrategy strategy) {
this.principals = new ConcurrentHashMap<>();
this.sessionIds = new ConcurrentHashMap<>();
this.principalIdentifierStrategy = strategy;
}

/**
* Secondary constructor for testing/custom maps.
*/
public SessionRegistryImpl(ConcurrentMap<Object, Set<String>> principals,
Map<String, SessionInformation> sessionIds) {
this.principals = principals;
this.sessionIds = sessionIds;
this.principalIdentifierStrategy = (existing, incoming) -> existing.equals(incoming);
}

@Override
Expand All @@ -76,11 +77,23 @@ public List<Object> getAllPrincipals() {

@Override
public List<SessionInformation> getAllSessions(Object principal, boolean includeExpiredSessions) {
Set<String> sessionsUsedByPrincipal = this.principals.get(principal);

Set<String> sessionsUsedByPrincipal = null;

// 🔥 Strategy-based lookup
for (Map.Entry<Object, Set<String>> entry : this.principals.entrySet()) {
if (this.principalIdentifierStrategy.matches(entry.getKey(), principal)) {
sessionsUsedByPrincipal = entry.getValue();
break;
}
}

if (sessionsUsedByPrincipal == null) {
return Collections.emptyList();
}

List<SessionInformation> list = new ArrayList<>(sessionsUsedByPrincipal.size());

for (String sessionId : sessionsUsedByPrincipal) {
SessionInformation sessionInformation = getSessionInformation(sessionId);
if (sessionInformation == null) {
Expand All @@ -90,6 +103,7 @@ public List<SessionInformation> getAllSessions(Object principal, boolean include
list.add(sessionInformation);
}
}

return list;
}

Expand All @@ -101,12 +115,14 @@ public List<SessionInformation> getAllSessions(Object principal, boolean include

@Override
public void onApplicationEvent(AbstractSessionEvent event) {

if (event instanceof SessionDestroyedEvent sessionDestroyedEvent) {
String sessionId = sessionDestroyedEvent.getId();
removeSessionInformation(sessionId);
}
else if (event instanceof SessionIdChangedEvent sessionIdChangedEvent) {
String oldSessionId = sessionIdChangedEvent.getOldSessionId();

if (this.sessionIds.containsKey(oldSessionId)) {
Object principal = this.sessionIds.get(oldSessionId).getPrincipal();
removeSessionInformation(oldSessionId);
Expand All @@ -126,49 +142,66 @@ public void refreshLastRequest(String sessionId) {

@Override
public void registerNewSession(String sessionId, Object principal) {

Assert.hasText(sessionId, "SessionId required as per interface contract");
Assert.notNull(principal, "Principal required as per interface contract");

if (getSessionInformation(sessionId) != null) {
removeSessionInformation(sessionId);
}

if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Registering session %s, for principal %s", sessionId, principal));
this.logger.debug(
LogMessage.format("Registering session %s, for principal %s", sessionId, principal));
}

this.sessionIds.put(sessionId, new SessionInformation(principal, sessionId, new Date()));

this.principals.compute(principal, (key, sessionsUsedByPrincipal) -> {
if (sessionsUsedByPrincipal == null) {
sessionsUsedByPrincipal = new CopyOnWriteArraySet<>();
}
sessionsUsedByPrincipal.add(sessionId);
this.logger.trace(LogMessage.format("Sessions used by '%s' : %s", principal, sessionsUsedByPrincipal));
this.logger.trace(
LogMessage.format("Sessions used by '%s' : %s", principal, sessionsUsedByPrincipal));
return sessionsUsedByPrincipal;
});
}

@Override
public void removeSessionInformation(String sessionId) {

Assert.hasText(sessionId, "SessionId required as per interface contract");

SessionInformation info = getSessionInformation(sessionId);

if (info == null) {
return;
}

if (this.logger.isTraceEnabled()) {
this.logger.debug("Removing session " + sessionId + " from set of registered sessions");
}

this.sessionIds.remove(sessionId);

this.principals.computeIfPresent(info.getPrincipal(), (key, sessionsUsedByPrincipal) -> {
this.logger
.debug(LogMessage.format("Removing session %s from principal's set of registered sessions", sessionId));

this.logger.debug(LogMessage.format(
"Removing session %s from principal's set of registered sessions", sessionId));

sessionsUsedByPrincipal.remove(sessionId);

if (sessionsUsedByPrincipal.isEmpty()) {
// No need to keep object in principals Map anymore
this.logger.debug(LogMessage.format("Removing principal %s from registry", info.getPrincipal()));
sessionsUsedByPrincipal = null;
this.logger.debug(
LogMessage.format("Removing principal %s from registry", info.getPrincipal()));
return null;
}
this.logger
.trace(LogMessage.format("Sessions used by '%s' : %s", info.getPrincipal(), sessionsUsedByPrincipal));

this.logger.trace(
LogMessage.format("Sessions used by '%s' : %s", info.getPrincipal(), sessionsUsedByPrincipal));

return sessionsUsedByPrincipal;
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.security.core.session;

import java.security.Principal;
import java.util.Date;
import java.util.List;

Expand Down Expand Up @@ -192,6 +193,42 @@ public String getNewSessionId() {
assertThat(this.sessionRegistry.getSessionInformation(newSessionId)).isNull();
}

@Test
public void principalsWithSameNameButDifferentInstancesAreTreatedAsDifferent() {
SessionRegistryImpl registry = new SessionRegistryImpl();

Principal principal1 = () -> "user";
Principal principal2 = () -> "user"; // Different instance, same name

String sessionId = "session-1";

registry.registerNewSession(sessionId, principal1);

// Default behavior: should NOT find session for different instance
assertThat(registry.getAllSessions(principal2, false)).isEmpty();
}

@Test
public void customPrincipalIdentifierStrategyMatchesPrincipalsByName() {
PrincipalIdentifierStrategy strategy =
(existing, incoming) ->
existing instanceof Principal e &&
incoming instanceof Principal i &&
e.getName().equals(i.getName());

SessionRegistryImpl registry = new SessionRegistryImpl(strategy);

Principal principal1 = () -> "user";
Principal principal2 = () -> "user"; // Different instance

String sessionId = "session-1";

registry.registerNewSession(sessionId, principal1);

assertThat(registry.getAllSessions(principal2, false)).hasSize(1);
}


private boolean contains(String sessionId, Object principal) {
List<SessionInformation> info = this.sessionRegistry.getAllSessions(principal, false);
for (SessionInformation sessionInformation : info) {
Expand Down