Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ public final void onOutput() throws HttpException, IOException {
}

if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0 && remoteSettingState == SettingsHandshake.ACKED) {
while (streamMap.size() < remoteConfig.getMaxConcurrentStreams()) {
final int outboundLimit = remoteConfig.getMaxConcurrentStreams();
while (countActiveLocalInitiated() < outboundLimit) {
final Command command = ioSession.poll();
if (command == null) {
break;
Expand Down Expand Up @@ -771,6 +772,8 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio
if (goAwayReceived ) {
throw new H2ConnectionException(H2Error.PROTOCOL_ERROR, "GOAWAY received");
}
final int inboundLimit = localConfig.getMaxConcurrentStreams();
final boolean refuseInbound = countActiveRemoteInitiated() >= inboundLimit;

updateLastStreamId(streamId);

Expand All @@ -785,10 +788,31 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio
}

stream = new H2Stream(channel, streamHandler, true);
if (refuseInbound) {
try {
stream.localReset(new H2StreamResetException(H2Error.REFUSED_STREAM, "Inbound stream concurrency limit exceeded"));
} catch (final IOException ignore) {
}
}

if (stream.isOutputReady()) {
stream.produceOutput();
}
streamMap.put(streamId, stream);
} else {
if (stream.isRemoteInitiated() && stream.channel.idle) {
final int inboundLimit = localConfig.getMaxConcurrentStreams();
final int active = countActiveRemoteInitiated();
if (active >= inboundLimit) {
try {
stream.localReset(new H2StreamResetException(
H2Error.REFUSED_STREAM, "Inbound stream concurrency limit exceeded"));
} catch (final IOException ignore) {
}
break;
}
stream.channel.idle = false;
}
}

try {
Expand Down Expand Up @@ -973,7 +997,7 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio
updateLastStreamId(promisedStreamId);

final H2StreamChannelImpl channel = new H2StreamChannelImpl(
promisedStreamId, false, initInputWinSize, initOutputWinSize);
promisedStreamId, true, initInputWinSize, initOutputWinSize);
final H2StreamHandler streamHandler;
if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0) {
streamHandler = createRemotelyInitiatedStream(channel, httpProcessor, connMetrics,
Expand All @@ -985,7 +1009,6 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio

final H2Stream promisedStream = new H2Stream(channel, streamHandler, true);
streamMap.put(promisedStreamId, promisedStream);

try {
consumePushPromiseFrame(frame, payload, promisedStream);
} catch (final H2StreamResetException ex) {
Expand Down Expand Up @@ -1218,6 +1241,12 @@ private void produceOutput() throws HttpException, IOException {
final Map.Entry<Integer, H2Stream> entry = it.next();
final H2Stream stream = entry.getValue();
if (!stream.isLocalClosed() && stream.getOutputWindow().get() > 0) {
if (!stream.isRemoteInitiated() && stream.channel.idle) {
final int outboundLimit = remoteConfig.getMaxConcurrentStreams();
if (countActiveLocalInitiated() >= outboundLimit) {
continue;
}
}
stream.produceOutput();
}
if (stream.isTerminated()) {
Expand Down Expand Up @@ -1260,6 +1289,7 @@ private void applyRemoteSettings(final H2Config config) throws H2ConnectionExcep
}
}
}
requestSessionOutput();
}

private void applyLocalSettings() throws H2ConnectionException {
Expand Down Expand Up @@ -1395,6 +1425,29 @@ H2StreamChannelImpl createChannel(final int streamId) {
return new H2StreamChannelImpl(streamId, false, initInputWinSize, initOutputWinSize);
}

// Count active streams by initiator (directional) for concurrency limits
private int countActiveLocalInitiated() {
int n = 0;
for (final Map.Entry<Integer, H2Stream> e : streamMap.entrySet()) {
final H2Stream s = e.getValue();
if (!s.isRemoteInitiated() && !s.isTerminated() && !s.channel.idle) {
n++;
}
}
return n;
}

private int countActiveRemoteInitiated() {
int n = 0;
for (final Map.Entry<Integer, H2Stream> e : streamMap.entrySet()) {
final H2Stream s = e.getValue();
if (s.isRemoteInitiated() && !s.isTerminated() && !s.channel.idle) {
n++;
}
}
return n;
}

class H2StreamChannelImpl implements H2StreamChannel {

private final int id;
Expand Down Expand Up @@ -1557,15 +1610,13 @@ boolean localReset(final int code) throws IOException {
if (isLocalReset()) {
return false;
}
ensureNotClosed();
// Allow RST_STREAM from any state, including reserved (idle) and even if local end is already closed.
localEndStream = true;
deadline = System.currentTimeMillis() + LINGER_TIME;
if (!idle) {
final RawFrame resetStream = frameFactory.createResetStream(id, code);
commitFrameInternal(resetStream);
return true;
}
return false;

final RawFrame resetStream = frameFactory.createResetStream(id, code);
commitFrameInternal(resetStream);
return true;
} finally {
ioSession.getLock().unlock();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.locks.Lock;

Expand Down Expand Up @@ -88,11 +89,14 @@ class TestAbstractH2StreamMultiplexer {
ArgumentCaptor<List<Header>> headersCaptor;
@Captor
ArgumentCaptor<Exception> exceptionCaptor;
@Captor
ArgumentCaptor<RawFrame> frameCaptor;

@BeforeEach
void prepareMocks() {
MockitoAnnotations.openMocks(this);
Mockito.when(protocolIOSession.getLock()).thenReturn(lock);
Mockito.when(protocolIOSession.poll()).thenReturn(null);
}

static class H2StreamMultiplexerImpl extends AbstractH2StreamMultiplexer {
Expand Down Expand Up @@ -666,7 +670,7 @@ void testInputHeaderContinuationFramesMaxLimit() throws Exception {
outBuffer.write(continuationFrame3, writableChannel);

Assertions.assertThrows(H2ConnectionException.class, () ->
streamMultiplexer.onInput(ByteBuffer.wrap(writableChannel.toByteArray())));
streamMultiplexer.onInput(ByteBuffer.wrap(writableChannel.toByteArray())));
}

@Test
Expand Down Expand Up @@ -767,5 +771,80 @@ void testStreamRemoteResetNoErrorRemoteAlreadyClosed() throws Exception {
Mockito.verify(streamHandler, Mockito.never()).failed(ArgumentMatchers.any());
}

}
@Test
void testInboundPushPromise_thenRefuseOnFirstHeadersWhenOverLimit() throws Exception {
final H2Config h2Config = H2Config.custom()
.setMaxConcurrentStreams(1) // allow only one active remote stream
.build();

final AbstractH2StreamMultiplexer mux = new H2StreamMultiplexerImpl(
protocolIOSession,
FRAME_FACTORY,
StreamIdGenerator.ODD,
httpProcessor,
CharCodingConfig.DEFAULT,
h2Config,
h2StreamListener,
() -> streamHandler);

final WritableByteChannelMock chan = new WritableByteChannelMock(2048);
final FrameOutputBuffer out = new FrameOutputBuffer(16 * 1024);

// Keep one remote stream (2) active
final ByteArrayBuffer respHdr = new ByteArrayBuffer(64);
final HPackEncoder enc = new HPackEncoder(H2Config.INIT.getHeaderTableSize(),
CharCodingSupport.createEncoder(CharCodingConfig.DEFAULT));
enc.encodeHeaders(respHdr,
Collections.singletonList(new BasicHeader(":status", "200")),
h2Config.isCompressionEnabled());
out.write(FRAME_FACTORY.createHeaders(
2, ByteBuffer.wrap(respHdr.array(), 0, respHdr.length()), true, false), chan);
mux.onInput(ByteBuffer.wrap(chan.toByteArray()));

// Send PUSH_PROMISE (do NOT verify outputs here)
chan.reset();
final int promisedId = 4;
final ByteArrayBuffer promiseBuf = new ByteArrayBuffer(128);
promiseBuf.append((byte) (promisedId >> 24));
promiseBuf.append((byte) (promisedId >> 16));
promiseBuf.append((byte) (promisedId >> 8));
promiseBuf.append((byte) promisedId);
enc.encodeHeaders(promiseBuf, Arrays.asList(
new BasicHeader(":method", "GET"),
new BasicHeader(":scheme", "https"),
new BasicHeader(":authority", "example.org"),
new BasicHeader(":path", "/pushed")
), h2Config.isCompressionEnabled());
out.write(FRAME_FACTORY.createPushPromise(
2, ByteBuffer.wrap(promiseBuf.array(), 0, promiseBuf.length()), true), chan);
mux.onInput(ByteBuffer.wrap(chan.toByteArray()));

chan.reset();
final ByteArrayBuffer pushedResp = new ByteArrayBuffer(32);
enc.encodeHeaders(pushedResp,
Collections.singletonList(new BasicHeader(":status", "200")),
h2Config.isCompressionEnabled());
out.write(FRAME_FACTORY.createHeaders(
4, ByteBuffer.wrap(pushedResp.array(), 0, pushedResp.length()), true, false), chan);
mux.onInput(ByteBuffer.wrap(chan.toByteArray()));

// Flush any queued outbound frames so listener sees them
mux.onOutput();

Mockito.verify(h2StreamListener, Mockito.atLeastOnce())
.onFrameOutput(ArgumentMatchers.same(mux), ArgumentMatchers.anyInt(), frameCaptor.capture());

boolean refused = false;
for (final RawFrame f : frameCaptor.getAllValues()) {
if (f.getType() == FrameType.RST_STREAM.getValue() && f.getStreamId() == promisedId) {
final int code = f.getPayload().duplicate().getInt();
Assertions.assertEquals(H2Error.REFUSED_STREAM.getCode(), code);
refused = true;
break;
}
}
Assertions.assertTrue(refused, "Expected RST_STREAM(REFUSED_STREAM) on first HEADERS for promised stream 4");
}


}