diff --git a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java index 6241eb4f63..a0d0cf15b0 100644 --- a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java +++ b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java @@ -528,7 +528,8 @@ public final void onOutput() throws HttpException, IOException { } if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0 && remoteSettingState == SettingsHandshake.ACKED) { - while (streamMap.size() < remoteConfig.getMaxConcurrentStreams()) { + final int outboundLimit = remoteConfig.getMaxConcurrentStreams(); + while (countActiveLocalInitiated() < outboundLimit) { final Command command = ioSession.poll(); if (command == null) { break; @@ -771,6 +772,8 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio if (goAwayReceived ) { throw new H2ConnectionException(H2Error.PROTOCOL_ERROR, "GOAWAY received"); } + final int inboundLimit = localConfig.getMaxConcurrentStreams(); + final boolean refuseInbound = countActiveRemoteInitiated() >= inboundLimit; updateLastStreamId(streamId); @@ -785,10 +788,31 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio } stream = new H2Stream(channel, streamHandler, true); + if (refuseInbound) { + try { + stream.localReset(new H2StreamResetException(H2Error.REFUSED_STREAM, "Inbound stream concurrency limit exceeded")); + } catch (final IOException ignore) { + } + } + if (stream.isOutputReady()) { stream.produceOutput(); } streamMap.put(streamId, stream); + } else { + if (stream.isRemoteInitiated() && stream.channel.idle) { + final int inboundLimit = localConfig.getMaxConcurrentStreams(); + final int active = countActiveRemoteInitiated(); + if (active >= inboundLimit) { + try { + stream.localReset(new H2StreamResetException( + H2Error.REFUSED_STREAM, "Inbound stream concurrency limit exceeded")); + } catch (final IOException ignore) { + } + break; + } + stream.channel.idle = false; + } } try { @@ -973,7 +997,7 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio updateLastStreamId(promisedStreamId); final H2StreamChannelImpl channel = new H2StreamChannelImpl( - promisedStreamId, false, initInputWinSize, initOutputWinSize); + promisedStreamId, true, initInputWinSize, initOutputWinSize); final H2StreamHandler streamHandler; if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0) { streamHandler = createRemotelyInitiatedStream(channel, httpProcessor, connMetrics, @@ -985,7 +1009,6 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio final H2Stream promisedStream = new H2Stream(channel, streamHandler, true); streamMap.put(promisedStreamId, promisedStream); - try { consumePushPromiseFrame(frame, payload, promisedStream); } catch (final H2StreamResetException ex) { @@ -1218,6 +1241,12 @@ private void produceOutput() throws HttpException, IOException { final Map.Entry entry = it.next(); final H2Stream stream = entry.getValue(); if (!stream.isLocalClosed() && stream.getOutputWindow().get() > 0) { + if (!stream.isRemoteInitiated() && stream.channel.idle) { + final int outboundLimit = remoteConfig.getMaxConcurrentStreams(); + if (countActiveLocalInitiated() >= outboundLimit) { + continue; + } + } stream.produceOutput(); } if (stream.isTerminated()) { @@ -1260,6 +1289,7 @@ private void applyRemoteSettings(final H2Config config) throws H2ConnectionExcep } } } + requestSessionOutput(); } private void applyLocalSettings() throws H2ConnectionException { @@ -1395,6 +1425,29 @@ H2StreamChannelImpl createChannel(final int streamId) { return new H2StreamChannelImpl(streamId, false, initInputWinSize, initOutputWinSize); } + // Count active streams by initiator (directional) for concurrency limits + private int countActiveLocalInitiated() { + int n = 0; + for (final Map.Entry e : streamMap.entrySet()) { + final H2Stream s = e.getValue(); + if (!s.isRemoteInitiated() && !s.isTerminated() && !s.channel.idle) { + n++; + } + } + return n; + } + + private int countActiveRemoteInitiated() { + int n = 0; + for (final Map.Entry e : streamMap.entrySet()) { + final H2Stream s = e.getValue(); + if (s.isRemoteInitiated() && !s.isTerminated() && !s.channel.idle) { + n++; + } + } + return n; + } + class H2StreamChannelImpl implements H2StreamChannel { private final int id; @@ -1557,15 +1610,13 @@ boolean localReset(final int code) throws IOException { if (isLocalReset()) { return false; } - ensureNotClosed(); + // Allow RST_STREAM from any state, including reserved (idle) and even if local end is already closed. localEndStream = true; deadline = System.currentTimeMillis() + LINGER_TIME; - if (!idle) { - final RawFrame resetStream = frameFactory.createResetStream(id, code); - commitFrameInternal(resetStream); - return true; - } - return false; + + final RawFrame resetStream = frameFactory.createResetStream(id, code); + commitFrameInternal(resetStream); + return true; } finally { ioSession.getLock().unlock(); } diff --git a/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/TestAbstractH2StreamMultiplexer.java b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/TestAbstractH2StreamMultiplexer.java index d0d322dbad..5b68261f0d 100644 --- a/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/TestAbstractH2StreamMultiplexer.java +++ b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/TestAbstractH2StreamMultiplexer.java @@ -31,6 +31,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.locks.Lock; @@ -88,11 +89,14 @@ class TestAbstractH2StreamMultiplexer { ArgumentCaptor> headersCaptor; @Captor ArgumentCaptor exceptionCaptor; + @Captor + ArgumentCaptor frameCaptor; @BeforeEach void prepareMocks() { MockitoAnnotations.openMocks(this); Mockito.when(protocolIOSession.getLock()).thenReturn(lock); + Mockito.when(protocolIOSession.poll()).thenReturn(null); } static class H2StreamMultiplexerImpl extends AbstractH2StreamMultiplexer { @@ -666,7 +670,7 @@ void testInputHeaderContinuationFramesMaxLimit() throws Exception { outBuffer.write(continuationFrame3, writableChannel); Assertions.assertThrows(H2ConnectionException.class, () -> - streamMultiplexer.onInput(ByteBuffer.wrap(writableChannel.toByteArray()))); + streamMultiplexer.onInput(ByteBuffer.wrap(writableChannel.toByteArray()))); } @Test @@ -767,5 +771,80 @@ void testStreamRemoteResetNoErrorRemoteAlreadyClosed() throws Exception { Mockito.verify(streamHandler, Mockito.never()).failed(ArgumentMatchers.any()); } -} + @Test + void testInboundPushPromise_thenRefuseOnFirstHeadersWhenOverLimit() throws Exception { + final H2Config h2Config = H2Config.custom() + .setMaxConcurrentStreams(1) // allow only one active remote stream + .build(); + + final AbstractH2StreamMultiplexer mux = new H2StreamMultiplexerImpl( + protocolIOSession, + FRAME_FACTORY, + StreamIdGenerator.ODD, + httpProcessor, + CharCodingConfig.DEFAULT, + h2Config, + h2StreamListener, + () -> streamHandler); + + final WritableByteChannelMock chan = new WritableByteChannelMock(2048); + final FrameOutputBuffer out = new FrameOutputBuffer(16 * 1024); + + // Keep one remote stream (2) active + final ByteArrayBuffer respHdr = new ByteArrayBuffer(64); + final HPackEncoder enc = new HPackEncoder(H2Config.INIT.getHeaderTableSize(), + CharCodingSupport.createEncoder(CharCodingConfig.DEFAULT)); + enc.encodeHeaders(respHdr, + Collections.singletonList(new BasicHeader(":status", "200")), + h2Config.isCompressionEnabled()); + out.write(FRAME_FACTORY.createHeaders( + 2, ByteBuffer.wrap(respHdr.array(), 0, respHdr.length()), true, false), chan); + mux.onInput(ByteBuffer.wrap(chan.toByteArray())); + + // Send PUSH_PROMISE (do NOT verify outputs here) + chan.reset(); + final int promisedId = 4; + final ByteArrayBuffer promiseBuf = new ByteArrayBuffer(128); + promiseBuf.append((byte) (promisedId >> 24)); + promiseBuf.append((byte) (promisedId >> 16)); + promiseBuf.append((byte) (promisedId >> 8)); + promiseBuf.append((byte) promisedId); + enc.encodeHeaders(promiseBuf, Arrays.asList( + new BasicHeader(":method", "GET"), + new BasicHeader(":scheme", "https"), + new BasicHeader(":authority", "example.org"), + new BasicHeader(":path", "/pushed") + ), h2Config.isCompressionEnabled()); + out.write(FRAME_FACTORY.createPushPromise( + 2, ByteBuffer.wrap(promiseBuf.array(), 0, promiseBuf.length()), true), chan); + mux.onInput(ByteBuffer.wrap(chan.toByteArray())); + + chan.reset(); + final ByteArrayBuffer pushedResp = new ByteArrayBuffer(32); + enc.encodeHeaders(pushedResp, + Collections.singletonList(new BasicHeader(":status", "200")), + h2Config.isCompressionEnabled()); + out.write(FRAME_FACTORY.createHeaders( + 4, ByteBuffer.wrap(pushedResp.array(), 0, pushedResp.length()), true, false), chan); + mux.onInput(ByteBuffer.wrap(chan.toByteArray())); + + // Flush any queued outbound frames so listener sees them + mux.onOutput(); + + Mockito.verify(h2StreamListener, Mockito.atLeastOnce()) + .onFrameOutput(ArgumentMatchers.same(mux), ArgumentMatchers.anyInt(), frameCaptor.capture()); + + boolean refused = false; + for (final RawFrame f : frameCaptor.getAllValues()) { + if (f.getType() == FrameType.RST_STREAM.getValue() && f.getStreamId() == promisedId) { + final int code = f.getPayload().duplicate().getInt(); + Assertions.assertEquals(H2Error.REFUSED_STREAM.getCode(), code); + refused = true; + break; + } + } + Assertions.assertTrue(refused, "Expected RST_STREAM(REFUSED_STREAM) on first HEADERS for promised stream 4"); + } + +}