Skip to content

Commit 2b9510d

Browse files
lachlan-robertspoutsma
authored andcommitted
Prevent possible ReadPendingException from multiple demand.
Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
1 parent 8cc8b11 commit 2b9510d

File tree

2 files changed

+72
-37
lines changed

2 files changed

+72
-37
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,22 @@ public void onWebSocketText(String message) {
9595
byte[] bytes = message.getBytes(StandardCharsets.UTF_8);
9696
DataBuffer buffer = this.delegateSession.bufferFactory().wrap(bytes);
9797
WebSocketMessage webSocketMessage = new WebSocketMessage(Type.TEXT, buffer);
98-
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
98+
this.delegateSession.handleMessage(webSocketMessage);
9999
}
100100

101101
@Override
102102
public void onWebSocketBinary(ByteBuffer byteBuffer, Callback callback) {
103103
DataBuffer buffer = this.delegateSession.bufferFactory().wrap(byteBuffer);
104104
buffer = new JettyDataBuffer(buffer, callback);
105105
WebSocketMessage webSocketMessage = new WebSocketMessage(Type.BINARY, buffer);
106-
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
106+
this.delegateSession.handleMessage(webSocketMessage);
107107
}
108108

109109
@Override
110110
public void onWebSocketPong(ByteBuffer payload) {
111111
DataBuffer buffer = this.delegateSession.bufferFactory().wrap(BufferUtil.copy(payload));
112112
WebSocketMessage webSocketMessage = new WebSocketMessage(Type.PONG, buffer);
113-
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
113+
this.delegateSession.handleMessage(webSocketMessage);
114114
}
115115

116116
@Override

spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketSession.java

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818

1919
import java.nio.ByteBuffer;
2020
import java.nio.charset.StandardCharsets;
21-
import java.util.concurrent.atomic.AtomicLong;
21+
import java.util.concurrent.locks.Lock;
22+
import java.util.concurrent.locks.ReentrantLock;
2223

24+
import org.eclipse.jetty.util.BufferUtil;
2325
import org.eclipse.jetty.websocket.api.Callback;
2426
import org.eclipse.jetty.websocket.api.Session;
2527
import org.reactivestreams.Publisher;
@@ -48,10 +50,12 @@
4850
public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
4951

5052
private final Flux<WebSocketMessage> flux;
51-
private final AtomicLong requested = new AtomicLong(0);
52-
5353
private final Sinks.One<CloseStatus> closeStatusSink = Sinks.one();
54-
@Nullable
54+
private final Lock lock = new ReentrantLock();
55+
private long requested = 0;
56+
private boolean awaitingDemand = false;
57+
58+
@SuppressWarnings("NotNullFieldNotInitialized")
5559
private FluxSink<WebSocketMessage> sink;
5660

5761
@Nullable
@@ -70,15 +74,49 @@ public JettyWebSocketSession(Session session, HandshakeInfo info, DataBufferFact
7074
this.sink = emitter;
7175
emitter.onRequest(n ->
7276
{
73-
requested.addAndGet(n);
74-
tryDemand();
77+
boolean demand = false;
78+
lock.lock();
79+
try
80+
{
81+
requested += n;
82+
if (!awaitingDemand && requested > 0) {
83+
requested--;
84+
awaitingDemand = true;
85+
demand = true;
86+
}
87+
}
88+
finally {
89+
lock.unlock();
90+
}
91+
92+
if (demand)
93+
getDelegate().demand();
7594
});
7695
});
7796
}
7897

79-
void handleMessage(WebSocketMessage.Type type, WebSocketMessage message) {
98+
void handleMessage(WebSocketMessage message) {
8099
this.sink.next(message);
81-
tryDemand();
100+
101+
boolean demand = false;
102+
lock.lock();
103+
try
104+
{
105+
if (!awaitingDemand)
106+
throw new IllegalStateException();
107+
awaitingDemand = false;
108+
if (requested > 0) {
109+
requested--;
110+
awaitingDemand = true;
111+
demand = true;
112+
}
113+
}
114+
finally {
115+
lock.unlock();
116+
}
117+
118+
if (demand)
119+
getDelegate().demand();
82120
}
83121

84122
void handleError(Throwable ex) {
@@ -127,23 +165,6 @@ public Flux<WebSocketMessage> receive() {
127165
return flux;
128166
}
129167

130-
private void tryDemand()
131-
{
132-
while (true)
133-
{
134-
long r = requested.get();
135-
if (r == 0)
136-
return;
137-
138-
// TODO: protect against readpending from multiple demand.
139-
if (requested.compareAndSet(r, r - 1))
140-
{
141-
getDelegate().demand();
142-
return;
143-
}
144-
}
145-
}
146-
147168
@Override
148169
public Mono<Void> send(Publisher<WebSocketMessage> messages) {
149170
return Flux.from(messages)
@@ -162,17 +183,31 @@ protected Mono<Void> sendMessage(WebSocketMessage message) {
162183
session.sendText(text, completable);
163184
}
164185
else {
165-
// TODO: Ping and Pong message should combine payload into single buffer?
166-
try (DataBuffer.ByteBufferIterator iterator = dataBuffer.readableByteBuffers()) {
167-
while (iterator.hasNext()) {
168-
ByteBuffer byteBuffer = iterator.next();
169-
switch (message.getType()) {
170-
case BINARY -> session.sendBinary(byteBuffer, completable);
171-
case PING -> session.sendPing(byteBuffer, completable);
172-
case PONG -> session.sendPong(byteBuffer, completable);
173-
default -> throw new IllegalArgumentException("Unexpected message type: " + message.getType());
186+
switch (message.getType()) {
187+
case BINARY ->
188+
{
189+
try (DataBuffer.ByteBufferIterator iterator = dataBuffer.readableByteBuffers()) {
190+
while (iterator.hasNext()) {
191+
ByteBuffer byteBuffer = iterator.next();
192+
session.sendBinary(byteBuffer, completable);
193+
}
174194
}
175195
}
196+
case PING ->
197+
{
198+
// Maximum size of Control frame payload is 125, per RFC 6455.
199+
ByteBuffer buffer = BufferUtil.allocate(125);
200+
dataBuffer.toByteBuffer(buffer);
201+
session.sendPing(buffer, completable);
202+
}
203+
case PONG ->
204+
{
205+
// Maximum size of Control frame payload is 125, per RFC 6455.
206+
ByteBuffer buffer = BufferUtil.allocate(125);
207+
dataBuffer.toByteBuffer(buffer);
208+
session.sendPong(buffer, completable);
209+
}
210+
default -> throw new IllegalArgumentException("Unexpected message type: " + message.getType());
176211
}
177212
}
178213
return Mono.fromFuture(completable);

0 commit comments

Comments
 (0)