2525import java .util .Set ;
2626import java .util .concurrent .ConcurrentHashMap ;
2727import java .util .concurrent .atomic .AtomicInteger ;
28+ import java .util .function .Consumer ;
2829
2930import org .apache .commons .logging .Log ;
3031import org .apache .commons .logging .LogFactory ;
@@ -106,9 +107,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
106107
107108private @ Nullable MessageHeaderInitializer headerInitializer ;
108109
109- private @ Nullable Map <String , MessageChannel > orderedHandlingMessageChannels ;
110+ private final Map <String , SessionInfo > sessions = new ConcurrentHashMap <>() ;
110111
111- private final Map < String , Principal > stompAuthentications = new ConcurrentHashMap <>() ;
112+ private boolean preserveReceiveOrder ;
112113
113114private @ Nullable Boolean immutableMessageInterceptorPresent ;
114115
@@ -201,7 +202,7 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
201202 * @since 6.1
202203 */
203204public void setPreserveReceiveOrder (boolean preserveReceiveOrder ) {
204- this .orderedHandlingMessageChannels = ( preserveReceiveOrder ? new ConcurrentHashMap <>() : null ) ;
205+ this .preserveReceiveOrder = preserveReceiveOrder ;
205206}
206207
207208/**
@@ -210,7 +211,7 @@ public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
210211 * @since 6.1
211212 */
212213public boolean isPreserveReceiveOrder () {
213- return ( this .orderedHandlingMessageChannels != null ) ;
214+ return this .preserveReceiveOrder ;
214215}
215216
216217@ Override
@@ -245,7 +246,7 @@ public Stats getStats() {
245246 */
246247@ Override
247248public void handleMessageFromClient (WebSocketSession session ,
248- WebSocketMessage <?> webSocketMessage , MessageChannel targetChannel ) {
249+ WebSocketMessage <?> webSocketMessage , MessageChannel channel ) {
249250
250251List <Message <byte []>> messages ;
251252try {
@@ -288,35 +289,36 @@ else if (webSocketMessage instanceof BinaryMessage binaryMessage) {
288289return ;
289290}
290291
291- MessageChannel channelToUse = targetChannel ;
292- if (this .orderedHandlingMessageChannels != null ) {
293- channelToUse = this .orderedHandlingMessageChannels .computeIfAbsent (
294- session .getId (), id -> new OrderedMessageChannelDecorator (targetChannel , logger ));
295- }
292+ SessionInfo info = this .sessions .get (session .getId ());
293+ MessageChannel channelToUse = (info != null ? info .getMessageChannelToUse () : null );
296294
297295for (Message <byte []> message : messages ) {
298- StompHeaderAccessor headerAccessor =
299- MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
296+ StompHeaderAccessor headerAccessor = MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
300297Assert .state (headerAccessor != null , "No StompHeaderAccessor" );
301298
302299StompCommand command = headerAccessor .getCommand ();
303- boolean isConnect = StompCommand .CONNECT .equals (command ) || StompCommand .STOMP .equals (command );
304-
300+ boolean isConnect = ( StompCommand .CONNECT .equals (command ) || StompCommand .STOMP .equals (command ) );
301+ String sessionId = session . getId ();
305302boolean sent = false ;
303+
306304try {
305+ if (isConnect ) {
306+ channelToUse = (this .preserveReceiveOrder ? new OrderedMessageChannelDecorator (channel , logger ) : channel );
307+ info = new SessionInfo (channelToUse , session .getPrincipal ());
308+ SessionInfo prevInfo = this .sessions .putIfAbsent (sessionId , info );
309+ Assert .state (prevInfo == null , "Session already exists" );
310+ headerAccessor .setUserChangeCallback (info );
311+ }
312+ else {
313+ Assert .state (channelToUse != null , "Unknown session: " + sessionId );
314+ }
307315
308- headerAccessor .setSessionId (session . getId () );
316+ headerAccessor .setSessionId (sessionId );
309317headerAccessor .setSessionAttributes (session .getAttributes ());
310318headerAccessor .setUser (getUser (session ));
311- if (isConnect ) {
312- headerAccessor .setUserChangeCallback (user -> {
313- if (user != null && user != session .getPrincipal ()) {
314- this .stompAuthentications .put (session .getId (), user );
315- }
316- });
317- }
318319headerAccessor .setHeader (SimpMessageHeaderAccessor .HEART_BEAT_HEADER , headerAccessor .getHeartbeat ());
319- if (!detectImmutableMessageInterceptor (targetChannel )) {
320+
321+ if (!detectImmutableMessageInterceptor (channel )) {
320322headerAccessor .setImmutable ();
321323}
322324
@@ -356,23 +358,28 @@ else if (StompCommand.UNSUBSCRIBE.equals(command)) {
356358}
357359catch (Throwable ex ) {
358360if (logger .isDebugEnabled ()) {
359- logger .debug ("Failed to send message to MessageChannel in session " + session . getId () , ex );
361+ logger .debug ("Failed to send message to MessageChannel in session " + sessionId , ex );
360362}
361363else if (logger .isErrorEnabled ()) {
362364// Skip for unsent CONNECT or SUBSCRIBE (likely authentication/authorization issues)
363365if (sent || !(isConnect || StompCommand .SUBSCRIBE .equals (command ))) {
364366logger .error ("Failed to send message to MessageChannel in session " +
365- session . getId () + ":" + ex .getMessage ());
367+ sessionId + ":" + ex .getMessage ());
366368}
367369}
368370handleError (session , ex , message );
369371}
372+
373+ if (!sent && isConnect ) {
374+ this .sessions .remove (sessionId );
375+ break ;
376+ }
370377}
371378}
372379
373380private @ Nullable Principal getUser (WebSocketSession session ) {
374- Principal user = this .stompAuthentications .get (session .getId ());
375- return (user != null ? user : session .getPrincipal ());
381+ SessionInfo info = this .sessions .get (session .getId ());
382+ return (info != null ? info . getUser () : session .getPrincipal ());
376383}
377384
378385private void handleError (WebSocketSession session , Throwable ex , @ Nullable Message <byte []> clientMessage ) {
@@ -674,10 +681,7 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
674681outputChannel .send (message );
675682}
676683finally {
677- if (this .orderedHandlingMessageChannels != null ) {
678- this .orderedHandlingMessageChannels .remove (session .getId ());
679- }
680- this .stompAuthentications .remove (session .getId ());
684+ this .sessions .remove (session .getId ());
681685SimpAttributesContextHolder .resetAttributes ();
682686simpAttributes .sessionCompleted ();
683687}
@@ -707,6 +711,36 @@ public String toString() {
707711}
708712
709713
714+ private static class SessionInfo implements Consumer <Principal > {
715+
716+ private final MessageChannel channel ;
717+
718+ private final @ Nullable Principal webSocketUser ;
719+
720+ private volatile @ Nullable Principal stompUser ;
721+
722+ SessionInfo (MessageChannel channel , @ Nullable Principal user ) {
723+ this .channel = channel ;
724+ this .webSocketUser = user ;
725+ }
726+
727+ public MessageChannel getMessageChannelToUse () {
728+ return this .channel ;
729+ }
730+
731+ public @ Nullable Principal getUser () {
732+ return (this .stompUser != null ? this .stompUser : this .webSocketUser );
733+ }
734+
735+ @ Override
736+ public void accept (@ Nullable Principal stompUser ) {
737+ if (stompUser != null && stompUser != this .webSocketUser ) {
738+ this .stompUser = stompUser ;
739+ }
740+ }
741+ }
742+
743+
710744/**
711745 * Contract for access to session counters.
712746 * @since 5.2
0 commit comments