Skip to content

Commit f15c12a

Browse files
committed
Merge branch '6.2.x'
2 parents 3607f98 + c88bfc5 commit f15c12a

File tree

3 files changed

+77
-33
lines changed

3 files changed

+77
-33
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Set;
2626
import java.util.concurrent.ConcurrentHashMap;
2727
import java.util.concurrent.atomic.AtomicInteger;
28+
import java.util.function.Consumer;
2829

2930
import org.apache.commons.logging.Log;
3031
import org.apache.commons.logging.LogFactory;
@@ -106,9 +107,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
106107

107108
private @Nullable MessageHeaderInitializer headerInitializer;
108109

109-
private @Nullable Map<String, MessageChannel> orderedHandlingMessageChannels;
110+
private final Map<String, SessionInfo> sessions = new ConcurrentHashMap<>();
110111

111-
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();
112+
private boolean preserveReceiveOrder;
112113

113114
private @Nullable Boolean immutableMessageInterceptorPresent;
114115

@@ -201,7 +202,7 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
201202
* @since 6.1
202203
*/
203204
public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
204-
this.orderedHandlingMessageChannels = (preserveReceiveOrder ? new ConcurrentHashMap<>() : null);
205+
this.preserveReceiveOrder = preserveReceiveOrder;
205206
}
206207

207208
/**
@@ -210,7 +211,7 @@ public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
210211
* @since 6.1
211212
*/
212213
public boolean isPreserveReceiveOrder() {
213-
return (this.orderedHandlingMessageChannels != null);
214+
return this.preserveReceiveOrder;
214215
}
215216

216217
@Override
@@ -245,7 +246,7 @@ public Stats getStats() {
245246
*/
246247
@Override
247248
public void handleMessageFromClient(WebSocketSession session,
248-
WebSocketMessage<?> webSocketMessage, MessageChannel targetChannel) {
249+
WebSocketMessage<?> webSocketMessage, MessageChannel channel) {
249250

250251
List<Message<byte[]>> messages;
251252
try {
@@ -288,35 +289,36 @@ else if (webSocketMessage instanceof BinaryMessage binaryMessage) {
288289
return;
289290
}
290291

291-
MessageChannel channelToUse = targetChannel;
292-
if (this.orderedHandlingMessageChannels != null) {
293-
channelToUse = this.orderedHandlingMessageChannels.computeIfAbsent(
294-
session.getId(), id -> new OrderedMessageChannelDecorator(targetChannel, logger));
295-
}
292+
SessionInfo info = this.sessions.get(session.getId());
293+
MessageChannel channelToUse = (info != null ? info.getMessageChannelToUse() : null);
296294

297295
for (Message<byte[]> message : messages) {
298-
StompHeaderAccessor headerAccessor =
299-
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
296+
StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
300297
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
301298

302299
StompCommand command = headerAccessor.getCommand();
303-
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
304-
300+
boolean isConnect = (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command));
301+
String sessionId = session.getId();
305302
boolean sent = false;
303+
306304
try {
305+
if (isConnect) {
306+
channelToUse = (this.preserveReceiveOrder ? new OrderedMessageChannelDecorator(channel, logger) : channel);
307+
info = new SessionInfo(channelToUse, session.getPrincipal());
308+
SessionInfo prevInfo = this.sessions.putIfAbsent(sessionId, info);
309+
Assert.state(prevInfo == null, "Session already exists");
310+
headerAccessor.setUserChangeCallback(info);
311+
}
312+
else {
313+
Assert.state(channelToUse != null, "Unknown session: " + sessionId);
314+
}
307315

308-
headerAccessor.setSessionId(session.getId());
316+
headerAccessor.setSessionId(sessionId);
309317
headerAccessor.setSessionAttributes(session.getAttributes());
310318
headerAccessor.setUser(getUser(session));
311-
if (isConnect) {
312-
headerAccessor.setUserChangeCallback(user -> {
313-
if (user != null && user != session.getPrincipal()) {
314-
this.stompAuthentications.put(session.getId(), user);
315-
}
316-
});
317-
}
318319
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
319-
if (!detectImmutableMessageInterceptor(targetChannel)) {
320+
321+
if (!detectImmutableMessageInterceptor(channel)) {
320322
headerAccessor.setImmutable();
321323
}
322324

@@ -356,23 +358,28 @@ else if (StompCommand.UNSUBSCRIBE.equals(command)) {
356358
}
357359
catch (Throwable ex) {
358360
if (logger.isDebugEnabled()) {
359-
logger.debug("Failed to send message to MessageChannel in session " + session.getId(), ex);
361+
logger.debug("Failed to send message to MessageChannel in session " + sessionId, ex);
360362
}
361363
else if (logger.isErrorEnabled()) {
362364
// Skip for unsent CONNECT or SUBSCRIBE (likely authentication/authorization issues)
363365
if (sent || !(isConnect || StompCommand.SUBSCRIBE.equals(command))) {
364366
logger.error("Failed to send message to MessageChannel in session " +
365-
session.getId() + ":" + ex.getMessage());
367+
sessionId + ":" + ex.getMessage());
366368
}
367369
}
368370
handleError(session, ex, message);
369371
}
372+
373+
if (!sent && isConnect) {
374+
this.sessions.remove(sessionId);
375+
break;
376+
}
370377
}
371378
}
372379

373380
private @Nullable Principal getUser(WebSocketSession session) {
374-
Principal user = this.stompAuthentications.get(session.getId());
375-
return (user != null ? user : session.getPrincipal());
381+
SessionInfo info = this.sessions.get(session.getId());
382+
return (info != null ? info.getUser() : session.getPrincipal());
376383
}
377384

378385
private void handleError(WebSocketSession session, Throwable ex, @Nullable Message<byte[]> clientMessage) {
@@ -674,10 +681,7 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
674681
outputChannel.send(message);
675682
}
676683
finally {
677-
if (this.orderedHandlingMessageChannels != null) {
678-
this.orderedHandlingMessageChannels.remove(session.getId());
679-
}
680-
this.stompAuthentications.remove(session.getId());
684+
this.sessions.remove(session.getId());
681685
SimpAttributesContextHolder.resetAttributes();
682686
simpAttributes.sessionCompleted();
683687
}
@@ -707,6 +711,36 @@ public String toString() {
707711
}
708712

709713

714+
private static class SessionInfo implements Consumer<Principal> {
715+
716+
private final MessageChannel channel;
717+
718+
private final @Nullable Principal webSocketUser;
719+
720+
private volatile @Nullable Principal stompUser;
721+
722+
SessionInfo(MessageChannel channel, @Nullable Principal user) {
723+
this.channel = channel;
724+
this.webSocketUser = user;
725+
}
726+
727+
public MessageChannel getMessageChannelToUse() {
728+
return this.channel;
729+
}
730+
731+
public @Nullable Principal getUser() {
732+
return (this.stompUser != null ? this.stompUser : this.webSocketUser);
733+
}
734+
735+
@Override
736+
public void accept(@Nullable Principal stompUser) {
737+
if (stompUser != null && stompUser != this.webSocketUser) {
738+
this.stompUser = stompUser;
739+
}
740+
}
741+
}
742+
743+
710744
/**
711745
* Contract for access to session counters.
712746
* @since 5.2

spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,22 @@ void clientInboundChannelSendMessage() throws Exception {
101101
session.setOpen(true);
102102
webSocketHandler.afterConnectionEstablished(session);
103103

104+
webSocketHandler.handleMessage(session,
105+
StompTextMessageBuilder.create(StompCommand.CONNECT).headers("destination:/foo").build());
106+
104107
webSocketHandler.handleMessage(session,
105108
StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build());
106109

107110
Message<?> message = channel.messages.get(0);
108111
StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
109112
assertThat(accessor).isNotNull();
110113
assertThat(accessor.isMutable()).isFalse();
114+
assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.CONNECT);
115+
116+
message = channel.messages.get(1);
117+
accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
118+
assertThat(accessor).isNotNull();
119+
assertThat(accessor.isMutable()).isFalse();
111120
assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.MESSAGE);
112121
assertThat(accessor.getDestination()).isEqualTo("/foo");
113122
}

spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ void sendMessageToController(
8989

9090
super.setup(server, webSocketClient, testInfo);
9191

92-
TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build();
92+
TextMessage m1 = create(StompCommand.CONNECT).headers("accept-version:1.1").build();
93+
TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/simple").build();
9394

94-
try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, message), "/ws").get()) {
95+
try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, m1, m2), "/ws").get()) {
9596
assertThat(session).isNotNull();
9697
SimpleController controller = this.wac.getBean(SimpleController.class);
9798
assertThat(controller.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();

0 commit comments

Comments
 (0)